1//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <utility>
31
32using namespace mlir;
33using namespace mlir::shape;
34
35#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
36
37namespace {
38#include "ShapeCanonicalization.inc"
39} // namespace
40
41RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
42 return RankedTensorType::get({rank}, IndexType::get(ctx));
43}
44
45bool shape::isExtentTensorType(Type type) {
46 auto ranked = llvm::dyn_cast<RankedTensorType>(type);
47 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
48}
49
50LogicalResult shape::getShapeVec(Value input,
51 SmallVectorImpl<int64_t> &shapeValues) {
52 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
53 auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
54 if (!type.hasRank())
55 return failure();
56 llvm::append_range(shapeValues, type.getShape());
57 return success();
58 }
59 DenseIntElementsAttr attr;
60 if (matchPattern(value: input, pattern: m_Constant(bind_value: &attr))) {
61 llvm::append_range(shapeValues, attr.getValues<int64_t>());
62 return success();
63 }
64 return failure();
65}
66
67static bool isErrorPropagationPossible(TypeRange operandTypes) {
68 return llvm::any_of(operandTypes,
69 llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
70}
71
72static LogicalResult verifySizeOrIndexOp(Operation *op) {
73 assert(op != nullptr && op->getNumResults() == 1);
74 Type resultTy = op->getResultTypes().front();
75 if (isErrorPropagationPossible(operandTypes: op->getOperandTypes())) {
76 if (!llvm::isa<SizeType>(resultTy))
77 return op->emitOpError()
78 << "if at least one of the operands can hold error values then "
79 "the result must be of type `size` to propagate them";
80 }
81 return success();
82}
83
84static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
85 assert(op != nullptr && op->getNumResults() == 1);
86 Type resultTy = op->getResultTypes().front();
87 if (isErrorPropagationPossible(operandTypes: op->getOperandTypes())) {
88 if (!llvm::isa<ShapeType>(resultTy))
89 return op->emitOpError()
90 << "if at least one of the operands can hold error values then "
91 "the result must be of type `shape` to propagate them";
92 }
93 return success();
94}
95
96template <typename... Ty>
97static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
98 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
99}
100
101template <typename... Ty, typename... ranges>
102static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
103 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
104}
105
106//===----------------------------------------------------------------------===//
107// InlinerInterface
108//===----------------------------------------------------------------------===//
109
110namespace {
111/// This class defines the interface for inlining shape dialect ops.
112struct ShapeInlinerInterface : public DialectInlinerInterface {
113 using DialectInlinerInterface::DialectInlinerInterface;
114
115 // Returns true if the given region 'src' can be inlined into the region
116 // 'dest' that is attached to an operation registered to the current dialect.
117 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
118 IRMapping &) const final {
119 return true;
120 }
121
122 // Returns true if the given operation 'op', that is registered to this
123 // dialect, can be inlined into the region 'dest' that is attached to an
124 // operation registered to the current dialect.
125 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
126 IRMapping &) const final {
127 return true;
128 }
129};
130} // namespace
131
/// Dialect initialization hook: registers the ODS-generated shape ops and
/// types, installs the shape inliner interface, permits unregistered ops in
/// this dialect, and declares that `AssumingOp` and `AssumingYieldOp` promise
/// a `bufferization::BufferizableOpInterface` implementation (presumably
/// attached elsewhere as an external model — NOTE(review): confirm the
/// registration site).
void ShapeDialect::initialize() {
  // Op classes are expanded from the generated GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  // Type classes are expanded from the generated GET_TYPEDEF_LIST.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
                            AssumingYieldOp>();
}
149
150Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
151 Attribute value, Type type,
152 Location loc) {
153 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
154 return builder.create<ub::PoisonOp>(loc, type, poison);
155
156 if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
157 return builder.create<ConstShapeOp>(
158 loc, type, llvm::cast<DenseIntElementsAttr>(value));
159 if (llvm::isa<SizeType>(type))
160 return builder.create<ConstSizeOp>(loc, type,
161 llvm::cast<IntegerAttr>(value));
162 if (llvm::isa<WitnessType>(type))
163 return builder.create<ConstWitnessOp>(loc, type,
164 llvm::cast<BoolAttr>(value));
165
166 return arith::ConstantOp::materialize(builder, value, type, loc);
167}
168
169LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
170 NamedAttribute attribute) {
171 // Verify shape.lib attribute.
172 if (attribute.getName() == "shape.lib") {
173 if (!op->hasTrait<OpTrait::SymbolTable>())
174 return op->emitError(
175 "shape.lib attribute may only be on op implementing SymbolTable");
176
177 if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
178 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
179 if (!symbol)
180 return op->emitError("shape function library ")
181 << symbolRef << " not found";
182 return isa<shape::FunctionLibraryOp>(symbol)
183 ? success()
184 : op->emitError()
185 << symbolRef << " required to be shape function library";
186 }
187
188 if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
189 // Verify all entries are function libraries and mappings in libraries
190 // refer to unique ops.
191 DenseSet<StringAttr> key;
192 for (auto it : arr) {
193 if (!llvm::isa<SymbolRefAttr>(it))
194 return op->emitError(
195 "only SymbolRefAttr allowed in shape.lib attribute array");
196
197 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
198 SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
199 if (!shapeFnLib)
200 return op->emitError()
201 << it << " does not refer to FunctionLibraryOp";
202 for (auto mapping : shapeFnLib.getMapping()) {
203 if (!key.insert(mapping.getName()).second) {
204 return op->emitError("only one op to shape mapping allowed, found "
205 "multiple for `")
206 << mapping.getName() << "`";
207 }
208 }
209 }
210 return success();
211 }
212
213 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
214 "allowed as shape.lib attribute");
215 }
216 return success();
217}
218
219//===----------------------------------------------------------------------===//
220// AnyOp
221//===----------------------------------------------------------------------===//
222
223// TODO: Canonicalization should be implemented for shapes that can be
224// determined through mixtures of the known dimensions of the inputs.
225OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
226 // Only the last operand is checked because AnyOp is commutative.
227 if (adaptor.getInputs().back())
228 return adaptor.getInputs().back();
229
230 return nullptr;
231}
232
233//===----------------------------------------------------------------------===//
234// AssumingOp
235//===----------------------------------------------------------------------===//
236
237ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
238 result.regions.reserve(1);
239 Region *doRegion = result.addRegion();
240
241 auto &builder = parser.getBuilder();
242 OpAsmParser::UnresolvedOperand cond;
243 if (parser.parseOperand(cond) ||
244 parser.resolveOperand(cond, builder.getType<WitnessType>(),
245 result.operands))
246 return failure();
247
248 // Parse optional results type list.
249 if (parser.parseOptionalArrowTypeList(result.types))
250 return failure();
251
252 // Parse the region and add a terminator if elided.
253 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
254 return failure();
255 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
256
257 // Parse the optional attribute list.
258 if (parser.parseOptionalAttrDict(result.attributes))
259 return failure();
260 return success();
261}
262
263void AssumingOp::print(OpAsmPrinter &p) {
264 bool yieldsResults = !getResults().empty();
265
266 p << " " << getWitness();
267 if (yieldsResults)
268 p << " -> (" << getResultTypes() << ")";
269 p << ' ';
270 p.printRegion(getDoRegion(),
271 /*printEntryBlockArgs=*/false,
272 /*printBlockTerminators=*/yieldsResults);
273 p.printOptionalAttrDict((*this)->getAttrs());
274}
275
276namespace {
277// Removes AssumingOp with a passing witness and inlines the region.
278struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
279 using OpRewritePattern<AssumingOp>::OpRewritePattern;
280
281 LogicalResult matchAndRewrite(AssumingOp op,
282 PatternRewriter &rewriter) const override {
283 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
284 if (!witness || !witness.getPassingAttr())
285 return failure();
286
287 AssumingOp::inlineRegionIntoParent(op, rewriter);
288 return success();
289 }
290};
291
292struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
293 using OpRewritePattern<AssumingOp>::OpRewritePattern;
294
295 LogicalResult matchAndRewrite(AssumingOp op,
296 PatternRewriter &rewriter) const override {
297 Block *body = op.getBody();
298 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
299
300 // Find used values.
301 SmallVector<Value, 4> newYieldOperands;
302 for (auto [opResult, yieldOperand] :
303 llvm::zip(op.getResults(), yieldOp.getOperands())) {
304 if (!opResult.getUses().empty()) {
305 newYieldOperands.push_back(yieldOperand);
306 }
307 }
308
309 // Rewrite only if redundant results exist.
310 if (newYieldOperands.size() == yieldOp->getNumOperands())
311 return failure();
312
313 // Replace yield op in the old assuming op's body and move the entire region
314 // to the new assuming op.
315 rewriter.setInsertionPointToEnd(body);
316 auto newYieldOp =
317 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
318 rewriter.setInsertionPoint(op);
319 auto newOp = rewriter.create<AssumingOp>(
320 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
321 newOp.getDoRegion().takeBody(op.getDoRegion());
322
323 // Use the new results to replace the previously used ones.
324 SmallVector<Value, 4> replacementValues;
325 auto src = newOp.getResults().begin();
326 for (auto it : op.getResults()) {
327 if (it.getUses().empty())
328 replacementValues.push_back(nullptr);
329 else
330 replacementValues.push_back(*src++);
331 }
332 rewriter.replaceOp(op, replacementValues);
333 return success();
334 }
335};
336} // namespace
337
338void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
339 MLIRContext *context) {
340 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
341}
342
343// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
344void AssumingOp::getSuccessorRegions(
345 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
346 // AssumingOp has unconditional control flow into the region and back to the
347 // parent, so return the correct RegionSuccessor purely based on the index
348 // being None or 0.
349 if (!point.isParent()) {
350 regions.push_back(RegionSuccessor(getResults()));
351 return;
352 }
353
354 regions.push_back(RegionSuccessor(&getDoRegion()));
355}
356
357void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
358 PatternRewriter &rewriter) {
359 auto *blockBeforeAssuming = rewriter.getInsertionBlock();
360 auto *assumingBlock = op.getBody();
361 auto initPosition = rewriter.getInsertionPoint();
362 auto *blockAfterAssuming =
363 rewriter.splitBlock(blockBeforeAssuming, initPosition);
364
365 // Remove the AssumingOp and AssumingYieldOp.
366 auto &yieldOp = assumingBlock->back();
367 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
368 rewriter.replaceOp(op, yieldOp.getOperands());
369 rewriter.eraseOp(&yieldOp);
370
371 // Merge blocks together as there was no branching behavior from the
372 // AssumingOp.
373 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
374 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
375}
376
377void AssumingOp::build(
378 OpBuilder &builder, OperationState &result, Value witness,
379 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
380 OpBuilder::InsertionGuard g(builder);
381
382 result.addOperands(witness);
383 Region *bodyRegion = result.addRegion();
384 builder.createBlock(bodyRegion);
385
386 // Build body.
387 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
388 builder.create<AssumingYieldOp>(result.location, yieldValues);
389
390 SmallVector<Type, 2> assumingTypes;
391 for (Value v : yieldValues)
392 assumingTypes.push_back(v.getType());
393 result.addTypes(assumingTypes);
394}
395
396//===----------------------------------------------------------------------===//
397// AddOp
398//===----------------------------------------------------------------------===//
399
400LogicalResult mlir::shape::AddOp::inferReturnTypes(
401 MLIRContext *context, std::optional<Location> location,
402 AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
403 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
404 llvm::isa<SizeType>(adaptor.getRhs().getType()))
405 inferredReturnTypes.assign({SizeType::get(context)});
406 else
407 inferredReturnTypes.assign({IndexType::get(context)});
408 return success();
409}
410
411bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
412 // SizeType is compatible with IndexType.
413 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
414}
415
416OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
417 // add(x, 0) -> x
418 if (matchPattern(getRhs(), m_Zero()))
419 return getLhs();
420
421 return constFoldBinaryOp<IntegerAttr>(
422 adaptor.getOperands(),
423 [](APInt a, const APInt &b) { return std::move(a) + b; });
424}
425
426LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
427
428//===----------------------------------------------------------------------===//
429// AssumingAllOp
430//===----------------------------------------------------------------------===//
431
432namespace {
433
434// Merge multiple `shape.assuming_all` operations together.
435//
436// %0 = shape.assuming_all %w0, %w1
437// %1 = shape.assuming_all %w2, %0
438//
439// to:
440//
441// %0 = shape.assuming_all %w0, %w2, %w2
442struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
443 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
444
445 LogicalResult matchAndRewrite(AssumingAllOp op,
446 PatternRewriter &rewriter) const override {
447 SmallVector<Value> operands;
448
449 for (Value operand : op.getInputs()) {
450 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
451 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
452 else
453 operands.push_back(operand);
454 }
455
456 // We didn't find any other `assuming_all` ops to merge with.
457 if (operands.size() == op.getNumOperands())
458 return failure();
459
460 // Replace with a new `assuming_all` operation with merged constraints.
461 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
462 return success();
463 }
464};
465
466// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
467// are subsumed by others.
468//
469// %0 = shape.cstr_broadcastable %shape0, %shape1
470// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
471//
472// %2 = shape.cstr_broadcastable %shape3, %shape4
473// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
474//
475// %4 = shape.assuming_all %0, %1, %2, %3
476//
477// to:
478//
479// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
480// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
481// %2 = shape.assuming_all %0, %1
482//
483// In this example if shapes [0, 1, 2] are broadcastable, then it means that
484// shapes [0, 1] are broadcastable too, and can be removed from the list of
485// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
486// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
487struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
488 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
489
490 LogicalResult matchAndRewrite(AssumingAllOp op,
491 PatternRewriter &rewriter) const override {
492 // Collect all `CstrBroadcastableOp` operands first.
493 SetVector<CstrBroadcastableOp> operands;
494 for (Value operand : op.getInputs()) {
495 // TODO: Apply this optimization if some of the witnesses are not
496 // produced by the `cstr_broadcastable`.
497 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
498 if (!broadcastable)
499 return failure();
500
501 operands.insert(broadcastable);
502 }
503
504 // Skip trivial `assuming_all` operations.
505 if (operands.size() <= 1)
506 return failure();
507
508 // Collect shapes checked by `cstr_broadcastable` operands.
509 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
510 for (auto cstr : operands) {
511 DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
512 shapes.emplace_back(cstr, std::move(shapesSet));
513 }
514
515 // Sort by the number of shape operands (larger to smaller).
516 llvm::sort(shapes, [](auto a, auto b) {
517 return a.first.getNumOperands() > b.first.getNumOperands();
518 });
519
520 // We start from the `cst_broadcastable` operations with largest number of
521 // shape operands, and remove redundant `cst_broadcastable` operations. We
522 // do this until we find a set of `cst_broadcastable` operations with
523 // non-overlapping constraints.
524 SmallVector<CstrBroadcastableOp> markedForErase;
525
526 for (unsigned i = 0; i < shapes.size(); ++i) {
527 auto isSubset = [&](auto pair) {
528 return llvm::set_is_subset(pair.second, shapes[i].second);
529 };
530
531 // Keep redundant `cstr_broadcastable` operations to be erased.
532 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
533 for (auto *it0 = it; it0 < shapes.end(); ++it0)
534 markedForErase.push_back(it0->first);
535 shapes.erase(it, shapes.end());
536 }
537
538 // We didn't find any operands that could be removed.
539 if (markedForErase.empty())
540 return failure();
541
542 // Collect non-overlapping `cst_broadcastable` constraints.
543 SmallVector<Value> uniqueConstraints;
544 for (auto &shape : shapes)
545 uniqueConstraints.push_back(shape.first.getResult());
546
547 // Replace with a new `assuming_all` operation ...
548 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
549
550 // ... and maybe erase `cstr_broadcastable` ops without uses.
551 for (auto &op : markedForErase)
552 if (op->use_empty())
553 rewriter.eraseOp(op);
554
555 return success();
556 }
557};
558
559struct AssumingAllToCstrEqCanonicalization
560 : public OpRewritePattern<AssumingAllOp> {
561 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
562
563 LogicalResult matchAndRewrite(AssumingAllOp op,
564 PatternRewriter &rewriter) const override {
565 SmallVector<Value, 8> shapes;
566 for (Value w : op.getInputs()) {
567 auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
568 if (!cstrEqOp)
569 return failure();
570 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
571 return llvm::is_contained(shapes, s);
572 });
573 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
574 return failure();
575 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
576 }
577 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
578 return success();
579 }
580};
581
582template <typename OpTy>
583struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
584 using OpRewritePattern<OpTy>::OpRewritePattern;
585
586 LogicalResult matchAndRewrite(OpTy op,
587 PatternRewriter &rewriter) const override {
588 // Find unique operands.
589 SetVector<Value> unique(op.operand_begin(), op.operand_end());
590
591 // Reduce op to equivalent with unique operands.
592 if (unique.size() < op.getNumOperands()) {
593 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
594 unique.takeVector(), op->getAttrs());
595 return success();
596 }
597
598 return failure();
599 }
600};
601} // namespace
602
603void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
604 MLIRContext *context) {
605 patterns
606 .add<MergeAssumingAllOps, AssumingAllOneOp,
607 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
608 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
609}
610
611OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
612 // Iterate in reverse to first handle all constant operands. They are
613 // guaranteed to be the tail of the inputs because this is commutative.
614 for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
615 Attribute a = adaptor.getInputs()[idx];
616 // Cannot fold if any inputs are not constant;
617 if (!a)
618 return nullptr;
619
620 // We do not need to keep statically known values after handling them in
621 // this method.
622 getOperation()->eraseOperand(idx);
623
624 // Always false if any input is statically known false
625 if (!llvm::cast<BoolAttr>(a).getValue())
626 return a;
627 }
628 // If this is reached, all inputs were statically known passing.
629 return BoolAttr::get(getContext(), true);
630}
631
632LogicalResult AssumingAllOp::verify() {
633 // Ensure that AssumingAllOp contains at least one operand
634 if (getNumOperands() == 0)
635 return emitOpError("no operands specified");
636
637 return success();
638}
639
640//===----------------------------------------------------------------------===//
641// BroadcastOp
642//===----------------------------------------------------------------------===//
643
644OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
645 if (getShapes().size() == 1) {
646 // Otherwise, we need a cast which would be a canonicalization, not folding.
647 if (getShapes().front().getType() != getType())
648 return nullptr;
649 return getShapes().front();
650 }
651
652 if (!adaptor.getShapes().front())
653 return nullptr;
654
655 SmallVector<int64_t, 6> resultShape(
656 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
657 .getValues<int64_t>());
658
659 for (auto next : adaptor.getShapes().drop_front()) {
660 if (!next)
661 return nullptr;
662 auto nextShape = llvm::to_vector<6>(
663 llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
664
665 SmallVector<int64_t, 6> tmpShape;
666 // If the shapes are not compatible, we can't fold it.
667 // TODO: Fold to an "error".
668 if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
669 return nullptr;
670
671 resultShape.clear();
672 std::copy(tmpShape.begin(), tmpShape.end(),
673 std::back_inserter(resultShape));
674 }
675
676 Builder builder(getContext());
677 return builder.getIndexTensorAttr(resultShape);
678}
679
680LogicalResult BroadcastOp::verify() {
681 return verifyShapeOrExtentTensorOp(*this);
682}
683
684namespace {
685template <typename OpTy>
686struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
687 using OpRewritePattern<OpTy>::OpRewritePattern;
688
689 LogicalResult matchAndRewrite(OpTy op,
690 PatternRewriter &rewriter) const override {
691 auto isPotentiallyNonEmptyShape = [](Value shape) {
692 if (auto extentTensorTy =
693 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
694 if (extentTensorTy.getDimSize(0) == 0)
695 return false;
696 }
697 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
698 if (constShape.getShape().empty())
699 return false;
700 }
701 return true;
702 };
703 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
704 isPotentiallyNonEmptyShape);
705
706 // Replace the op with empty shape constant if all operants are reduced to
707 // be empty.
708 if (newOperands.empty()) {
709 rewriter.replaceOpWithNewOp<ConstShapeOp>(
710 op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
711 return success();
712 }
713
714 // Reduce op to equivalent without empty shape operands.
715 if (newOperands.size() < op.getNumOperands()) {
716 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
717 op->getAttrs());
718 return success();
719 }
720
721 return failure();
722 }
723};
724
725struct BroadcastForwardSingleOperandPattern
726 : public OpRewritePattern<BroadcastOp> {
727 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
728
729 LogicalResult matchAndRewrite(BroadcastOp op,
730 PatternRewriter &rewriter) const override {
731 if (op.getNumOperands() != 1)
732 return failure();
733 Value replacement = op.getShapes().front();
734
735 // Insert cast if needed.
736 if (replacement.getType() != op.getType()) {
737 auto loc = op.getLoc();
738 if (llvm::isa<ShapeType>(op.getType())) {
739 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
740 } else {
741 assert(!llvm::isa<ShapeType>(op.getType()) &&
742 !llvm::isa<ShapeType>(replacement.getType()) &&
743 "expect extent tensor cast");
744 replacement =
745 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
746 }
747 }
748
749 rewriter.replaceOp(op, replacement);
750 return success();
751 }
752};
753
754struct BroadcastFoldConstantOperandsPattern
755 : public OpRewritePattern<BroadcastOp> {
756 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
757
758 LogicalResult matchAndRewrite(BroadcastOp op,
759 PatternRewriter &rewriter) const override {
760 SmallVector<int64_t, 8> foldedConstantShape;
761 SmallVector<Value, 8> newShapeOperands;
762 for (Value shape : op.getShapes()) {
763 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
764 SmallVector<int64_t, 8> newFoldedConstantShape;
765 if (OpTrait::util::getBroadcastedShape(
766 foldedConstantShape,
767 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
768 newFoldedConstantShape)) {
769 foldedConstantShape = newFoldedConstantShape;
770 continue;
771 }
772 }
773 newShapeOperands.push_back(shape);
774 }
775
776 // Need at least two constant operands to fold anything.
777 if (op.getNumOperands() - newShapeOperands.size() < 2)
778 return failure();
779
780 auto foldedConstantOperandsTy = RankedTensorType::get(
781 {static_cast<int64_t>(foldedConstantShape.size())},
782 rewriter.getIndexType());
783 newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
784 op.getLoc(), foldedConstantOperandsTy,
785 rewriter.getIndexTensorAttr(foldedConstantShape)));
786 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
787 newShapeOperands);
788 return success();
789 }
790};
791
792template <typename OpTy>
793struct CanonicalizeCastExtentTensorOperandsPattern
794 : public OpRewritePattern<OpTy> {
795 using OpRewritePattern<OpTy>::OpRewritePattern;
796
797 LogicalResult matchAndRewrite(OpTy op,
798 PatternRewriter &rewriter) const override {
799 // Canonicalize operands.
800 bool anyChange = false;
801 auto canonicalizeOperand = [&](Value operand) -> Value {
802 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
803 // Only eliminate the cast if it holds no shape information.
804 bool isInformationLoosingCast =
805 llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
806 if (isInformationLoosingCast) {
807 anyChange = true;
808 return castOp.getSource();
809 }
810 }
811 return operand;
812 };
813 auto newOperands = llvm::to_vector<8>(
814 llvm::map_range(op.getOperands(), canonicalizeOperand));
815
816 // Rewrite op if any change required.
817 if (!anyChange)
818 return failure();
819 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
820 return success();
821 }
822};
823
824struct BroadcastConcretizeResultTypePattern
825 : public OpRewritePattern<BroadcastOp> {
826 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
827
828 LogicalResult matchAndRewrite(BroadcastOp op,
829 PatternRewriter &rewriter) const override {
830 // Only concretize dynamic extent tensor result types.
831 auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
832 if (!resultTy || !resultTy.isDynamicDim(0))
833 return failure();
834
835 // Infer resulting shape rank if possible.
836 int64_t maxRank = 0;
837 for (Value shape : op.getShapes()) {
838 if (auto extentTensorTy =
839 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
840 // Cannot infer resulting shape rank if any operand is dynamically
841 // ranked.
842 if (extentTensorTy.isDynamicDim(0))
843 return failure();
844 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
845 }
846 }
847
848 auto newOp = rewriter.create<BroadcastOp>(
849 op.getLoc(), getExtentTensorType(getContext(), maxRank),
850 op.getShapes());
851 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
852 return success();
853 }
854};
855} // namespace
856
857void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
858 MLIRContext *context) {
859 patterns.add<BroadcastConcretizeResultTypePattern,
860 BroadcastFoldConstantOperandsPattern,
861 BroadcastForwardSingleOperandPattern,
862 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
863 RemoveDuplicateOperandsPattern<BroadcastOp>,
864 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
865}
866
867//===----------------------------------------------------------------------===//
868// ConcatOp
869//===----------------------------------------------------------------------===//
870
871OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
872 if (!adaptor.getLhs() || !adaptor.getRhs())
873 return nullptr;
874 auto lhsShape = llvm::to_vector<6>(
875 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
876 auto rhsShape = llvm::to_vector<6>(
877 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
878 SmallVector<int64_t, 6> resultShape;
879 resultShape.append(lhsShape.begin(), lhsShape.end());
880 resultShape.append(rhsShape.begin(), rhsShape.end());
881 Builder builder(getContext());
882 return builder.getIndexTensorAttr(resultShape);
883}
884
885//===----------------------------------------------------------------------===//
886// ConstShapeOp
887//===----------------------------------------------------------------------===//
888
889void ConstShapeOp::print(OpAsmPrinter &p) {
890 p << " ";
891 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
892 p << "[";
893 interleaveComma(getShape().getValues<int64_t>(), p);
894 p << "] : ";
895 p.printType(getType());
896}
897
898ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
899 if (parser.parseOptionalAttrDict(result.attributes))
900 return failure();
901 // We piggy-back on ArrayAttr parsing, though we don't internally store the
902 // shape as an ArrayAttr.
903 // TODO: Implement custom parser and maybe make syntax a bit more concise.
904 Attribute extentsRaw;
905 NamedAttrList dummy;
906 if (parser.parseAttribute(extentsRaw, "dummy", dummy))
907 return failure();
908 auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
909 if (!extentsArray)
910 return failure();
911 SmallVector<int64_t, 6> ints;
912 for (Attribute extent : extentsArray) {
913 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
914 if (!attr)
915 return failure();
916 ints.push_back(attr.getInt());
917 }
918 Builder &builder = parser.getBuilder();
919 result.addAttribute("shape", builder.getIndexTensorAttr(ints));
920 Type resultTy;
921 if (parser.parseColonType(resultTy))
922 return failure();
923 result.types.push_back(resultTy);
924 return success();
925}
926
927OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
928
929void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
930 MLIRContext *context) {
931 patterns.add<TensorCastConstShape>(context);
932}
933
934LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
935 MLIRContext *context, std::optional<Location> location,
936 ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
937 Builder b(context);
938 const Properties prop = adaptor.getProperties();
939 inferredReturnTypes.assign({RankedTensorType::get(
940 {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
941 return success();
942}
943
944bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
945 TypeRange r) {
946 if (l.size() != 1 || r.size() != 1)
947 return false;
948
949 Type lhs = l.front();
950 Type rhs = r.front();
951
952 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
953 // Shape type is compatible with all other valid return types.
954 return true;
955 return lhs == rhs;
956}
957
958//===----------------------------------------------------------------------===//
959// CstrBroadcastableOp
960//===----------------------------------------------------------------------===//
961
962void CstrBroadcastableOp::getCanonicalizationPatterns(
963 RewritePatternSet &patterns, MLIRContext *context) {
964 // Canonicalization patterns have overlap with the considerations during
965 // folding in case additional shape information is inferred at some point that
966 // does not result in folding.
967 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
968 CstrBroadcastableEqOps,
969 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
970 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
971}
972
973// Return true if there is exactly one attribute not representing a scalar
974// broadcast.
975static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
976 bool nonScalarSeen = false;
977 for (Attribute a : attributes) {
978 if (!a || llvm::cast<DenseIntElementsAttr>(Val&: a).getNumElements() != 0) {
979 if (nonScalarSeen)
980 return false;
981 nonScalarSeen = true;
982 }
983 }
984 return true;
985}
986
987OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
988 // No broadcasting is needed if all operands but one are scalar.
989 if (hasAtMostSingleNonScalar(adaptor.getShapes()))
990 return BoolAttr::get(getContext(), true);
991
992 if ([&] {
993 SmallVector<SmallVector<int64_t, 6>, 6> extents;
994 for (const auto &operand : adaptor.getShapes()) {
995 if (!operand)
996 return false;
997 extents.push_back(llvm::to_vector<6>(
998 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
999 }
1000 return OpTrait::util::staticallyKnownBroadcastable(extents);
1001 }())
1002 return BoolAttr::get(getContext(), true);
1003
1004 // Lastly, see if folding can be completed based on what constraints are known
1005 // on the input shapes.
1006 if ([&] {
1007 SmallVector<SmallVector<int64_t, 6>, 6> extents;
1008 for (auto shapeValue : getShapes()) {
1009 extents.emplace_back();
1010 if (failed(getShapeVec(shapeValue, extents.back())))
1011 return false;
1012 }
1013 return OpTrait::util::staticallyKnownBroadcastable(extents);
1014 }())
1015 return BoolAttr::get(getContext(), true);
1016
1017 // Because a failing witness result here represents an eventual assertion
1018 // failure, we do not replace it with a constant witness.
1019 return nullptr;
1020}
1021
1022LogicalResult CstrBroadcastableOp::verify() {
1023 // Ensure that CstrBroadcastableOp contains at least two operands
1024 if (getNumOperands() < 2)
1025 return emitOpError("required at least 2 input shapes");
1026 return success();
1027}
1028
1029//===----------------------------------------------------------------------===//
1030// CstrEqOp
1031//===----------------------------------------------------------------------===//
1032
1033void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1034 MLIRContext *context) {
1035 // If inputs are equal, return passing witness
1036 patterns.add<CstrEqEqOps>(context);
1037}
1038
1039OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1040 if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1041 return a && a == adaptor.getShapes().front();
1042 }))
1043 return BoolAttr::get(getContext(), true);
1044
1045 // Because a failing witness result here represents an eventual assertion
1046 // failure, we do not try to replace it with a constant witness. Similarly, we
1047 // cannot if there are any non-const inputs.
1048 return nullptr;
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// ConstSizeOp
1053//===----------------------------------------------------------------------===//
1054
1055void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1056 int64_t value) {
1057 build(builder, result, builder.getIndexAttr(value));
1058}
1059
1060OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1061
1062void ConstSizeOp::getAsmResultNames(
1063 llvm::function_ref<void(Value, StringRef)> setNameFn) {
1064 SmallString<4> buffer;
1065 llvm::raw_svector_ostream os(buffer);
1066 os << "c" << getValue();
1067 setNameFn(getResult(), os.str());
1068}
1069
1070//===----------------------------------------------------------------------===//
1071// ConstWitnessOp
1072//===----------------------------------------------------------------------===//
1073
1074OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1075
1076//===----------------------------------------------------------------------===//
1077// CstrRequireOp
1078//===----------------------------------------------------------------------===//
1079
1080OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1081 return adaptor.getPred();
1082}
1083
1084//===----------------------------------------------------------------------===//
1085// DimOp
1086//===----------------------------------------------------------------------===//
1087
1088std::optional<int64_t> DimOp::getConstantIndex() {
1089 if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1090 return constSizeOp.getValue().getLimitedValue();
1091 if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1092 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1093 return std::nullopt;
1094}
1095
1096OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1097 Type valType = getValue().getType();
1098 auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1099 if (!valShapedType || !valShapedType.hasRank())
1100 return nullptr;
1101 std::optional<int64_t> index = getConstantIndex();
1102 if (!index.has_value())
1103 return nullptr;
1104 if (index.value() < 0 || index.value() >= valShapedType.getRank())
1105 return nullptr;
1106 auto extent = valShapedType.getDimSize(*index);
1107 if (ShapedType::isDynamic(extent))
1108 return nullptr;
1109 return IntegerAttr::get(IndexType::get(getContext()), extent);
1110}
1111
1112LogicalResult mlir::shape::DimOp::inferReturnTypes(
1113 MLIRContext *context, std::optional<Location> location,
1114 DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1115 inferredReturnTypes.assign({adaptor.getIndex().getType()});
1116 return success();
1117}
1118
1119bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1120 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1121}
1122
1123//===----------------------------------------------------------------------===//
1124// DivOp
1125//===----------------------------------------------------------------------===//
1126
1127OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1128 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1129 if (!lhs)
1130 return nullptr;
1131 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1132 if (!rhs || rhs.getValue().isZero())
1133 return nullptr;
1134
1135 // Division in APInt does not follow floor(lhs, rhs) when the result is
1136 // negative. Rather, APInt rounds toward zero.
1137 APInt quotient, remainder;
1138 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1139 if (quotient.isNegative() && !remainder.isZero()) {
1140 quotient -= 1;
1141 }
1142
1143 Type indexTy = IndexType::get(getContext());
1144 return IntegerAttr::get(indexTy, quotient);
1145}
1146
1147LogicalResult mlir::shape::DivOp::inferReturnTypes(
1148 MLIRContext *context, std::optional<Location> location,
1149 DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1150 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1151 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1152 inferredReturnTypes.assign({SizeType::get(context)});
1153 else
1154 inferredReturnTypes.assign({IndexType::get(context)});
1155 return success();
1156}
1157
1158bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1159 // SizeType is compatible with IndexType.
1160 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1161}
1162
1163LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1164
1165//===----------------------------------------------------------------------===//
1166// ShapeEqOp
1167//===----------------------------------------------------------------------===//
1168
1169OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1170 bool allSame = true;
1171 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1172 return {};
1173 for (Attribute operand : adaptor.getShapes().drop_front()) {
1174 if (!operand)
1175 return {};
1176 allSame = allSame && operand == adaptor.getShapes().front();
1177 }
1178 return BoolAttr::get(getContext(), allSame);
1179}
1180
1181//===----------------------------------------------------------------------===//
1182// IndexToSizeOp
1183//===----------------------------------------------------------------------===//
1184
1185OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1186 // Constant values of both types, `shape.size` and `index`, are represented as
1187 // `IntegerAttr`s which makes constant folding simple.
1188 if (Attribute arg = adaptor.getArg())
1189 return arg;
1190 return {};
1191}
1192
1193void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1194 MLIRContext *context) {
1195 patterns.add<SizeToIndexToSizeCanonicalization>(context);
1196}
1197
1198//===----------------------------------------------------------------------===//
1199// FromExtentsOp
1200//===----------------------------------------------------------------------===//
1201
1202OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1203 if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1204 return nullptr;
1205 SmallVector<int64_t, 6> extents;
1206 for (auto attr : adaptor.getExtents())
1207 extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1208 Builder builder(getContext());
1209 return builder.getIndexTensorAttr(extents);
1210}
1211
1212//===----------------------------------------------------------------------===//
1213// FunctionLibraryOp
1214//===----------------------------------------------------------------------===//
1215
1216void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1217 StringRef name) {
1218 result.attributes.push_back(builder.getNamedAttr(
1219 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1220}
1221
1222FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1223 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1224 getMapping().get(op->getName().getIdentifier()));
1225 if (!attr)
1226 return nullptr;
1227 return lookupSymbol<FuncOp>(attr);
1228}
1229
1230ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1231 OperationState &result) {
1232 // Parse the op name.
1233 StringAttr nameAttr;
1234 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1235 result.attributes))
1236 return failure();
1237
1238 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1239 return failure();
1240
1241 auto *bodyRegion = result.addRegion();
1242 if (parser.parseRegion(*bodyRegion))
1243 return failure();
1244
1245 if (parser.parseKeyword("mapping"))
1246 return failure();
1247
1248 DictionaryAttr mappingAttr;
1249 if (parser.parseAttribute(mappingAttr,
1250 parser.getBuilder().getType<NoneType>(), "mapping",
1251 result.attributes))
1252 return failure();
1253 return success();
1254}
1255
1256void FunctionLibraryOp::print(OpAsmPrinter &p) {
1257 p << ' ';
1258 p.printSymbolName(getName());
1259 p.printOptionalAttrDictWithKeyword(
1260 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1261 p << ' ';
1262 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1263 /*printBlockTerminators=*/false);
1264 p << " mapping ";
1265 p.printAttributeWithoutType(getMappingAttr());
1266}
1267
1268//===----------------------------------------------------------------------===//
1269// FuncOp
1270//===----------------------------------------------------------------------===//
1271
1272FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1273 ArrayRef<NamedAttribute> attrs) {
1274 OpBuilder builder(location->getContext());
1275 OperationState state(location, getOperationName());
1276 FuncOp::build(builder, state, name, type, attrs);
1277 return cast<FuncOp>(Operation::create(state));
1278}
1279FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1280 Operation::dialect_attr_range attrs) {
1281 SmallVector<NamedAttribute, 8> attrRef(attrs);
1282 return create(location, name, type, llvm::ArrayRef(attrRef));
1283}
1284FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1285 ArrayRef<NamedAttribute> attrs,
1286 ArrayRef<DictionaryAttr> argAttrs) {
1287 FuncOp func = create(location, name, type, attrs);
1288 func.setAllArgAttrs(argAttrs);
1289 return func;
1290}
1291
1292void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1293 FunctionType type, ArrayRef<NamedAttribute> attrs,
1294 ArrayRef<DictionaryAttr> argAttrs) {
1295 state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1296 builder.getStringAttr(name));
1297 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1298 TypeAttr::get(type));
1299 state.attributes.append(attrs.begin(), attrs.end());
1300 state.addRegion();
1301
1302 if (argAttrs.empty())
1303 return;
1304 assert(type.getNumInputs() == argAttrs.size());
1305 call_interface_impl::addArgAndResultAttrs(
1306 builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
1307 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1308}
1309
1310ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1311 auto buildFuncType =
1312 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1313 function_interface_impl::VariadicFlag,
1314 std::string &) { return builder.getFunctionType(argTypes, results); };
1315
1316 return function_interface_impl::parseFunctionOp(
1317 parser, result, /*allowVariadic=*/false,
1318 getFunctionTypeAttrName(result.name), buildFuncType,
1319 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1320}
1321
1322void FuncOp::print(OpAsmPrinter &p) {
1323 function_interface_impl::printFunctionOp(
1324 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1325 getArgAttrsAttrName(), getResAttrsAttrName());
1326}
1327
1328//===----------------------------------------------------------------------===//
1329// GetExtentOp
1330//===----------------------------------------------------------------------===//
1331
1332std::optional<int64_t> GetExtentOp::getConstantDim() {
1333 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1334 return constSizeOp.getValue().getLimitedValue();
1335 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1336 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1337 return std::nullopt;
1338}
1339
1340OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1341 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1342 if (!elements)
1343 return nullptr;
1344 std::optional<int64_t> dim = getConstantDim();
1345 if (!dim.has_value())
1346 return nullptr;
1347 if (dim.value() >= elements.getNumElements())
1348 return nullptr;
1349 return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1350}
1351
1352void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1353 int64_t dim) {
1354 auto loc = result.location;
1355 auto dimAttr = builder.getIndexAttr(dim);
1356 if (llvm::isa<ShapeType>(shape.getType())) {
1357 Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1358 build(builder, result, builder.getType<SizeType>(), shape, dim);
1359 } else {
1360 Value dim =
1361 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1362 build(builder, result, builder.getIndexType(), shape, dim);
1363 }
1364}
1365
1366LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1367 MLIRContext *context, std::optional<Location> location,
1368 GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1369 inferredReturnTypes.assign({IndexType::get(context)});
1370 return success();
1371}
1372
1373bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1374 TypeRange r) {
1375 // SizeType is compatible with IndexType.
1376 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1377}
1378
1379LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1380
1381//===----------------------------------------------------------------------===//
1382// IsBroadcastableOp
1383//===----------------------------------------------------------------------===//
1384
1385void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1386 MLIRContext *context) {
1387 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1388}
1389
1390OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1391 // Can always broadcast fewer than two shapes.
1392 if (adaptor.getShapes().size() < 2) {
1393 return BoolAttr::get(getContext(), true);
1394 }
1395
1396 return nullptr;
1397}
1398
1399//===----------------------------------------------------------------------===//
1400// MeetOp
1401//===----------------------------------------------------------------------===//
1402
1403LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1404 MLIRContext *context, std::optional<Location> location,
1405 MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1406 if (adaptor.getOperands().empty())
1407 return failure();
1408
1409 auto isShapeType = [](Type arg) {
1410 if (llvm::isa<ShapeType>(arg))
1411 return true;
1412 return isExtentTensorType(arg);
1413 };
1414
1415 ValueRange::type_range types = adaptor.getOperands().getTypes();
1416 Type acc = types.front();
1417 for (auto t : drop_begin(types)) {
1418 Type l = acc, r = t;
1419 if (!llvm::isa<ShapeType, SizeType>(l))
1420 std::swap(l, r);
1421
1422 // Handle sizes, propagate error type if present.
1423 if (llvm::isa<SizeType>(l)) {
1424 if (llvm::isa<SizeType, IndexType>(r))
1425 acc = l;
1426 else
1427 return emitOptionalError(location, "requires all sizes or shapes");
1428 } else if (llvm::isa<IndexType>(l)) {
1429 if (llvm::isa<IndexType>(r))
1430 acc = r;
1431 else
1432 return emitOptionalError(location, "requires all sizes or shapes");
1433 } else if (llvm::isa<ShapeType>(l)) {
1434 // Handle shapes, propagate error type if present.
1435 if (isShapeType(r))
1436 acc = l;
1437 else
1438 return emitOptionalError(location, "requires all sizes or shapes");
1439 } else if (isExtentTensorType(l)) {
1440 auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1441 auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1442 if (ShapedType::isDynamic(rank1))
1443 acc = l;
1444 else if (ShapedType::isDynamic(rank2))
1445 acc = r;
1446 else if (rank1 != rank2)
1447 return emitOptionalError(location, "unequal shape cardinality");
1448 else
1449 acc = l;
1450 }
1451 }
1452 inferredReturnTypes.assign({acc});
1453 return success();
1454}
1455
1456bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1457 if (l.size() != 1 || r.size() != 1)
1458 return false;
1459 if (l == r)
1460 return true;
1461
1462 Type lhs = l.front();
1463 Type rhs = r.front();
1464
1465 if (!llvm::isa<ShapeType, SizeType>(lhs))
1466 std::swap(lhs, rhs);
1467
1468 if (llvm::isa<SizeType>(lhs))
1469 return llvm::isa<SizeType, IndexType>(rhs);
1470 if (llvm::isa<ShapeType>(lhs))
1471 return llvm::isa<ShapeType, TensorType>(rhs);
1472
1473 if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1474 return true;
1475 return false;
1476}
1477
1478//===----------------------------------------------------------------------===//
1479// RankOp
1480//===----------------------------------------------------------------------===//
1481
1482OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1483 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1484 if (!shape)
1485 return {};
1486 int64_t rank = shape.getNumElements();
1487 Builder builder(getContext());
1488 return builder.getIndexAttr(rank);
1489}
1490
1491/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1492/// Constant folding fails in cases where only the rank is constant, not the
1493/// shape itself.
1494/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1495///
1496/// Example:
1497///
1498/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1499/// %rank = shape.rank %shape
1500///
1501/// becomes
1502///
1503/// %rank = shape.const_size 3
1504
1505namespace {
1506struct RankShapeOfCanonicalizationPattern
1507 : public OpRewritePattern<shape::RankOp> {
1508 using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1509
1510 LogicalResult matchAndRewrite(shape::RankOp op,
1511 PatternRewriter &rewriter) const override {
1512 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1513 if (!shapeOfOp)
1514 return failure();
1515 auto rankedTensorType =
1516 llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1517 if (!rankedTensorType)
1518 return failure();
1519 int64_t rank = rankedTensorType.getRank();
1520 if (llvm::isa<IndexType>(op.getType())) {
1521 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1522 rank);
1523 } else if (llvm::isa<shape::SizeType>(op.getType())) {
1524 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1525 } else {
1526 return failure();
1527 }
1528 return success();
1529 }
1530};
1531} // namespace
1532
1533void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1534 MLIRContext *context) {
1535 patterns.add<RankShapeOfCanonicalizationPattern>(context);
1536}
1537
1538LogicalResult mlir::shape::RankOp::inferReturnTypes(
1539 MLIRContext *context, std::optional<Location> location,
1540 RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1541 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1542 inferredReturnTypes.assign({SizeType::get(context)});
1543 else
1544 inferredReturnTypes.assign({IndexType::get(context)});
1545 return success();
1546}
1547
1548bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1549 // SizeType is compatible with IndexType.
1550 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1551}
1552
1553LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1554
1555//===----------------------------------------------------------------------===//
1556// NumElementsOp
1557//===----------------------------------------------------------------------===//
1558
1559OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1560
1561 // Fold only when argument constant.
1562 Attribute shape = adaptor.getShape();
1563 if (!shape)
1564 return {};
1565
1566 APInt product(64, 1);
1567 for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1568 product *= value;
1569 Builder builder(getContext());
1570 return builder.getIndexAttr(product.getLimitedValue());
1571}
1572
1573LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1574 MLIRContext *context, std::optional<Location> location,
1575 NumElementsOp::Adaptor adaptor,
1576 SmallVectorImpl<Type> &inferredReturnTypes) {
1577 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1578 inferredReturnTypes.assign({SizeType::get(context)});
1579 else
1580 inferredReturnTypes.assign({IndexType::get(context)});
1581 return success();
1582}
1583
1584bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1585 TypeRange r) {
1586 // SizeType is compatible with IndexType.
1587 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1588}
1589
1590LogicalResult shape::NumElementsOp::verify() {
1591 return verifySizeOrIndexOp(*this);
1592}
1593
1594//===----------------------------------------------------------------------===//
1595// MaxOp
1596//===----------------------------------------------------------------------===//
1597
1598OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1599 // If operands are equal, just propagate one.
1600 if (getLhs() == getRhs())
1601 return getLhs();
1602 return nullptr;
1603}
1604
1605LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1606 MLIRContext *context, std::optional<Location> location,
1607 MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1608 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1609 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1610 else
1611 inferredReturnTypes.assign({SizeType::get(context)});
1612 return success();
1613}
1614
1615bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1616 if (l.size() != 1 || r.size() != 1)
1617 return false;
1618 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1619 return true;
1620 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1621 return true;
1622 return false;
1623}
1624
1625//===----------------------------------------------------------------------===//
1626// MinOp
1627//===----------------------------------------------------------------------===//
1628
1629OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1630 // If operands are equal, just propagate one.
1631 if (getLhs() == getRhs())
1632 return getLhs();
1633 return nullptr;
1634}
1635
1636LogicalResult mlir::shape::MinOp::inferReturnTypes(
1637 MLIRContext *context, std::optional<Location> location,
1638 MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1639 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1640 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1641 else
1642 inferredReturnTypes.assign({SizeType::get(context)});
1643 return success();
1644}
1645
1646bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1647 if (l.size() != 1 || r.size() != 1)
1648 return false;
1649 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1650 return true;
1651 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1652 return true;
1653 return false;
1654}
1655
1656//===----------------------------------------------------------------------===//
1657// MulOp
1658//===----------------------------------------------------------------------===//
1659
1660OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1661 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1662 if (!lhs)
1663 return nullptr;
1664 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1665 if (!rhs)
1666 return nullptr;
1667 APInt folded = lhs.getValue() * rhs.getValue();
1668 Type indexTy = IndexType::get(getContext());
1669 return IntegerAttr::get(indexTy, folded);
1670}
1671
1672LogicalResult mlir::shape::MulOp::inferReturnTypes(
1673 MLIRContext *context, std::optional<Location> location,
1674 MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1675 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1676 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1677 inferredReturnTypes.assign({SizeType::get(context)});
1678 else
1679 inferredReturnTypes.assign({IndexType::get(context)});
1680 return success();
1681}
1682
1683bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1684 // SizeType is compatible with IndexType.
1685 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1686}
1687
1688LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1689
1690//===----------------------------------------------------------------------===//
1691// ShapeOfOp
1692//===----------------------------------------------------------------------===//
1693
1694namespace {
1695/// Replace shape_of(x) where x has a constant shape with a const_shape op.
1696struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1697 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1698
1699 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1700 PatternRewriter &rewriter) const override {
1701 auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1702 if (!type || !type.hasStaticShape())
1703 return failure();
1704 Location loc = op.getLoc();
1705 Value constShape =
1706 rewriter
1707 .create<ConstShapeOp>(loc,
1708 rewriter.getIndexTensorAttr(type.getShape()))
1709 .getResult();
1710 if (constShape.getType() != op.getResult().getType())
1711 constShape = rewriter.create<tensor::CastOp>(
1712 loc, op.getResult().getType(), constShape);
1713 rewriter.replaceOp(op, constShape);
1714 return success();
1715 }
1716};
1717
1718// Canonicalize
1719//
1720// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1721// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1722//
1723// to
1724//
1725// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1726// %1 = %shape
1727//
1728struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
1729 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1730
1731 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1732 PatternRewriter &rewriter) const override {
1733 auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1734 if (!tensorReshapeOp)
1735 return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
1736 if (!isa<TensorType>(op.getType()))
1737 return rewriter.notifyMatchFailure(op, "result is not a tensor");
1738
1739 // Operand 'shape' of 'tensor.reshape' may now be used as the result of
1740 // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
1741 // formed IR, it may not be identical (dynamically vs statically shaped),
1742 // in which case it needs to be cast first using 'tensor.cast'.
1743 // Additionally, it may not have identical element type (i32 vs index)
1744 // while it has identical shaped type (dynamic vs static), in which case it
1745 // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
1746 // op result must be shape or extent tensor.
1747 Value shape = tensorReshapeOp.getShape();
1748
1749 auto opTensorTy = cast<RankedTensorType>(op.getType());
1750 auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
1751
1752 if (opTensorTy != shapeTensorTy) {
1753 if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1754 shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
1755 else if (!isExtentTensorType(shapeTensorTy))
1756 shape =
1757 rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
1758 }
1759
1760 rewriter.replaceOp(op, shape);
1761 return success();
1762 }
1763};
1764
1765// Canonicalize
1766// ```
1767// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1768// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1769// ```
1770// to
1771// ```
1772// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1773// ```
1774struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1775 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1776
1777 LogicalResult matchAndRewrite(tensor::CastOp op,
1778 PatternRewriter &rewriter) const override {
1779 auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1780 if (!ty || ty.getRank() != 1)
1781 return failure();
1782
1783 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1784 if (!shapeOfOp)
1785 return failure();
1786
1787 // Argument type must be ranked and must not conflict.
1788 auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1789 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1790 return failure();
1791
1792 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1793 return success();
1794 }
1795};
1796} // namespace
1797
1798void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1799 MLIRContext *context) {
1800 patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1801 ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1802 context);
1803}
1804
1805LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1806 MLIRContext *context, std::optional<Location> location,
1807 ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1808 if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1809 inferredReturnTypes.assign({ShapeType::get(context)});
1810 else {
1811 auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1812 int64_t rank =
1813 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1814 Type indexTy = IndexType::get(context);
1815 Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1816 inferredReturnTypes.assign({extentTensorTy});
1817 }
1818 return success();
1819}
1820
1821bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1822 if (l.size() != 1 || r.size() != 1)
1823 return false;
1824 if (l == r)
1825 return true;
1826
1827 Type lhs = l.front();
1828 Type rhs = r.front();
1829
1830 if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1831 !llvm::isa<ShapeType, ShapedType>(rhs))
1832 return false;
1833
1834 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1835 // Shape type is compatible with all other valid return types.
1836 return true;
1837
1838 if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1839 return true;
1840 return false;
1841}
1842
1843LogicalResult shape::ShapeOfOp::verify() {
1844 return verifyShapeOrExtentTensorOp(*this);
1845}
1846
1847//===----------------------------------------------------------------------===//
1848// SizeToIndexOp
1849//===----------------------------------------------------------------------===//
1850
1851OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1852 // Constant values of both types, `shape.size` and `index`, are represented as
1853 // `IntegerAttr`s which makes constant folding simple.
1854 if (Attribute arg = adaptor.getArg())
1855 return arg;
1856 return OpFoldResult();
1857}
1858
1859void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1860 MLIRContext *context) {
1861 patterns.add<IndexToSizeToIndexCanonicalization>(context);
1862}
1863
1864bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1865 if (inputs.size() != 1 || outputs.size() != 1)
1866 return false;
1867 return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1868 llvm::isa<IndexType>(outputs[0]);
1869}
1870
1871//===----------------------------------------------------------------------===//
1872// YieldOp
1873//===----------------------------------------------------------------------===//
1874
1875LogicalResult shape::YieldOp::verify() {
1876 auto *parentOp = (*this)->getParentOp();
1877 auto results = parentOp->getResults();
1878 auto operands = getOperands();
1879
1880 if (parentOp->getNumResults() != getNumOperands())
1881 return emitOpError() << "number of operands does not match number of "
1882 "results of its parent";
1883 for (auto e : llvm::zip(results, operands))
1884 if (std::get<0>(e).getType() != std::get<1>(e).getType())
1885 return emitOpError() << "types mismatch between yield op and its parent";
1886
1887 return success();
1888}
1889
1890//===----------------------------------------------------------------------===//
1891// SplitAtOp
1892//===----------------------------------------------------------------------===//
1893
1894LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1895 SmallVectorImpl<OpFoldResult> &results) {
1896 if (!adaptor.getOperand() || !adaptor.getIndex())
1897 return failure();
1898 auto shapeVec = llvm::to_vector<6>(
1899 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1900 auto shape = llvm::ArrayRef(shapeVec);
1901 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1902 // Verify that the split point is in the correct range.
1903 // TODO: Constant fold to an "error".
1904 int64_t rank = shape.size();
1905 if (-rank > splitPoint || splitPoint > rank)
1906 return failure();
1907 if (splitPoint < 0)
1908 splitPoint += shape.size();
1909 Builder builder(adaptor.getOperand().getContext());
1910 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1911 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1912 return success();
1913}
1914
1915//===----------------------------------------------------------------------===//
1916// ToExtentTensorOp
1917//===----------------------------------------------------------------------===//
1918
1919OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1920 if (!adaptor.getInput())
1921 return OpFoldResult();
1922 Builder builder(getContext());
1923 auto shape = llvm::to_vector<6>(
1924 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1925 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1926 builder.getIndexType());
1927 return DenseIntElementsAttr::get(type, shape);
1928}
1929
1930bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1931 if (inputs.size() != 1 || outputs.size() != 1)
1932 return false;
1933 if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1934 if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1935 inputTensor.getRank() != 1)
1936 return false;
1937 } else if (!llvm::isa<ShapeType>(inputs[0])) {
1938 return false;
1939 }
1940
1941 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1942 return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1943}
1944
1945//===----------------------------------------------------------------------===//
1946// ReduceOp
1947//===----------------------------------------------------------------------===//
1948
1949void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1950 ValueRange initVals) {
1951 OpBuilder::InsertionGuard g(builder);
1952 result.addOperands(shape);
1953 result.addOperands(initVals);
1954
1955 Region *bodyRegion = result.addRegion();
1956 Block *bodyBlock = builder.createBlock(
1957 bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1958
1959 Type elementType;
1960 if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1961 elementType = tensorType.getElementType();
1962 else
1963 elementType = SizeType::get(builder.getContext());
1964 bodyBlock->addArgument(elementType, shape.getLoc());
1965
1966 for (Value initVal : initVals) {
1967 bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1968 result.addTypes(initVal.getType());
1969 }
1970}
1971
1972LogicalResult ReduceOp::verify() {
1973 // Verify block arg types.
1974 Block &block = getRegion().front();
1975
1976 // The block takes index, extent, and aggregated values as arguments.
1977 auto blockArgsCount = getInitVals().size() + 2;
1978 if (block.getNumArguments() != blockArgsCount)
1979 return emitOpError() << "ReduceOp body is expected to have "
1980 << blockArgsCount << " arguments";
1981
1982 // The first block argument is the index and must always be of type `index`.
1983 if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1984 return emitOpError(
1985 "argument 0 of ReduceOp body is expected to be of IndexType");
1986
1987 // The second block argument is the extent and must be of type `size` or
1988 // `index`, depending on whether the reduce operation is applied to a shape or
1989 // to an extent tensor.
1990 Type extentTy = block.getArgument(1).getType();
1991 if (llvm::isa<ShapeType>(getShape().getType())) {
1992 if (!llvm::isa<SizeType>(extentTy))
1993 return emitOpError("argument 1 of ReduceOp body is expected to be of "
1994 "SizeType if the ReduceOp operates on a ShapeType");
1995 } else {
1996 if (!llvm::isa<IndexType>(extentTy))
1997 return emitOpError(
1998 "argument 1 of ReduceOp body is expected to be of IndexType if the "
1999 "ReduceOp operates on an extent tensor");
2000 }
2001
2002 for (const auto &type : llvm::enumerate(getInitVals()))
2003 if (block.getArgument(type.index() + 2).getType() != type.value().getType())
2004 return emitOpError() << "type mismatch between argument "
2005 << type.index() + 2
2006 << " of ReduceOp body and initial value "
2007 << type.index();
2008 return success();
2009}
2010
2011ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2012 // Parse operands.
2013 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2014 Type shapeOrExtentTensorType;
2015 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
2016 OpAsmParser::Delimiter::Paren) ||
2017 parser.parseColonType(shapeOrExtentTensorType) ||
2018 parser.parseOptionalArrowTypeList(result.types))
2019 return failure();
2020
2021 // Resolve operands.
2022 auto initVals = llvm::ArrayRef(operands).drop_front();
2023 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
2024 result.operands) ||
2025 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2026 result.operands))
2027 return failure();
2028
2029 // Parse the body.
2030 Region *body = result.addRegion();
2031 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2032 return failure();
2033
2034 // Parse attributes.
2035 if (parser.parseOptionalAttrDict(result.attributes))
2036 return failure();
2037
2038 return success();
2039}
2040
2041void ReduceOp::print(OpAsmPrinter &p) {
2042 p << '(' << getShape() << ", " << getInitVals()
2043 << ") : " << getShape().getType();
2044 p.printOptionalArrowTypeList(getResultTypes());
2045 p << ' ';
2046 p.printRegion(getRegion());
2047 p.printOptionalAttrDict((*this)->getAttrs());
2048}
2049
2050#define GET_OP_CLASSES
2051#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2052
2053#define GET_TYPEDEF_CLASSES
2054#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
2055

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Shape/IR/Shape.cpp