1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Arith/Utils/Utils.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/Dialect/Utils/StaticValueUtils.h"
13#include "mlir/IR/AffineMap.h"
14#include "mlir/IR/Builders.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/Matchers.h"
17#include "mlir/IR/OpDefinition.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/IR/TypeUtilities.h"
20#include "mlir/Interfaces/InferTypeOpInterface.h"
21#include "mlir/Interfaces/SideEffectInterfaces.h"
22#include "mlir/Interfaces/ViewLikeInterface.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallBitVector.h"
25
26using namespace mlir;
27using namespace mlir::memref;
28
29/// Materialize a single constant operation from a given attribute value with
30/// the desired resultant type.
31Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
32 Attribute value, Type type,
33 Location loc) {
34 return arith::ConstantOp::materialize(builder, value, type, loc);
35}
36
37//===----------------------------------------------------------------------===//
38// Common canonicalization pattern support logic
39//===----------------------------------------------------------------------===//
40
41/// This is a common class used for patterns of the form
42/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
43/// into the root operation directly.
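///
/// For illustration only (a hand-written sketch, not taken from a test), a
/// fold enabled by this helper on a consuming op such as `memref.dealloc`
/// looks like:
/// ```mlir
/// %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
/// memref.dealloc %0 : memref<?xf32>
/// ```
/// which becomes:
/// ```mlir
/// memref.dealloc %arg : memref<8xf32>
/// ```
/// `%arg` is assumed to be a value of type `memref<8xf32>` defined elsewhere.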
44LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
45 bool folded = false;
46 for (OpOperand &operand : op->getOpOperands()) {
47 auto cast = operand.get().getDefiningOp<CastOp>();
48 if (cast && operand.get() != inner &&
49 !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
50 operand.set(cast.getOperand());
51 folded = true;
52 }
53 }
  return success(folded);
55}
56
57/// Return an unranked/ranked tensor type for the given unranked/ranked memref
58/// type.
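///
/// For example (illustrative, with the layout simply discarded):
/// ```
/// memref<4x?xf32, strided<[?, 1]>> -> tensor<4x?xf32>
/// memref<*xf32>                    -> tensor<*xf32>
/// ```
/// Any other type maps to `NoneType`.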
59Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
60 if (auto memref = llvm::dyn_cast<MemRefType>(type))
61 return RankedTensorType::get(memref.getShape(), memref.getElementType());
62 if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
63 return UnrankedTensorType::get(memref.getElementType());
64 return NoneType::get(type.getContext());
65}
66
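/// Return the size of dimension `dim` of the memref `value` as an
/// OpFoldResult: a static extent is returned as an IndexAttr, while a dynamic
/// extent is materialized as a `memref.dim` operation.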
67OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
68 int64_t dim) {
69 auto memrefType = llvm::cast<MemRefType>(value.getType());
70 SmallVector<OpFoldResult> result;
71 if (memrefType.isDynamicDim(dim))
72 return builder.createOrFold<memref::DimOp>(loc, value, dim);
73
  return builder.getIndexAttr(memrefType.getDimSize(dim));
75}
76
77SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78 Location loc, Value value) {
79 auto memrefType = llvm::cast<MemRefType>(value.getType());
80 SmallVector<OpFoldResult> result;
81 for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
83 return result;
84}
85
86//===----------------------------------------------------------------------===//
87// Utility functions for propagating static information
88//===----------------------------------------------------------------------===//
89
90/// Helper function that infers the constant values from a list of \p values,
91/// a \p memRefTy, and another helper function \p getAttributes.
92/// The inferred constant values replace the related `OpFoldResult` in
93/// \p values.
94///
95/// \note This function shouldn't be used directly, instead, use the
96/// `getConstifiedMixedXXX` methods from the related operations.
97///
/// \p getAttributes returns a list of potentially constant values, as determined
99/// by \p isDynamic, from the given \p memRefTy. The returned list must have as
100/// many elements as \p values or be empty.
101///
102/// E.g., consider the following example:
103/// ```
104/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
105/// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
106/// ```
107/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
108/// Now using this helper function with:
109/// - `values == [2, %dyn_stride]`,
110/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
111/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
112/// `getStridesAndOffset`), and
113/// - `isDynamic == ShapedType::isDynamic`
114/// Will yield: `values == [2, 1]`
115static void constifyIndexValues(
116 SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
117 MLIRContext *ctxt,
118 llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
119 llvm::function_ref<bool(int64_t)> isDynamic) {
120 SmallVector<int64_t> constValues = getAttributes(memRefTy);
121 Builder builder(ctxt);
122 for (const auto &it : llvm::enumerate(constValues)) {
123 int64_t constValue = it.value();
124 if (!isDynamic(constValue))
125 values[it.index()] = builder.getIndexAttr(constValue);
126 }
127 for (OpFoldResult &ofr : values) {
128 if (ofr.is<Attribute>()) {
129 // FIXME: We shouldn't need to do that, but right now, the static indices
130 // are created with the wrong type: `i64` instead of `index`.
131 // As a result, if we were to keep the attribute as is, we may fail to see
132 // that two attributes are equal because one would have the i64 type and
133 // the other the index type.
134 // The alternative would be to create constant indices with getI64Attr in
135 // this and the previous loop, but it doesn't logically make sense (we are
      // dealing with indices here) and would only strengthen the inconsistency
137 // around how static indices are created (some places use getI64Attr,
138 // others use getIndexAttr).
139 // The workaround here is to stick to the IndexAttr type for all the
140 // values, hence we recreate the attribute even when it is already static
141 // to make sure the type is consistent.
      ofr = builder.getIndexAttr(
          llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
144 continue;
145 }
    std::optional<int64_t> maybeConstant =
        getConstantIntValue(ofr.get<Value>());
148 if (maybeConstant)
149 ofr = builder.getIndexAttr(*maybeConstant);
150 }
151}
152
153/// Wrapper around `getShape` that conforms to the function signature
154/// expected for `getAttributes` in `constifyIndexValues`.
155static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
156 ArrayRef<int64_t> sizes = memRefTy.getShape();
157 return SmallVector<int64_t>(sizes.begin(), sizes.end());
158}
159
160/// Wrapper around `getStridesAndOffset` that returns only the offset and
161/// conforms to the function signature expected for `getAttributes` in
162/// `constifyIndexValues`.
163static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
164 SmallVector<int64_t> strides;
165 int64_t offset;
166 LogicalResult hasStaticInformation =
167 getStridesAndOffset(memrefType, strides, offset);
  if (failed(hasStaticInformation))
169 return SmallVector<int64_t>();
170 return SmallVector<int64_t>(1, offset);
171}
172
173/// Wrapper around `getStridesAndOffset` that returns only the strides and
174/// conforms to the function signature expected for `getAttributes` in
175/// `constifyIndexValues`.
176static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
177 SmallVector<int64_t> strides;
178 int64_t offset;
179 LogicalResult hasStaticInformation =
180 getStridesAndOffset(memrefType, strides, offset);
  if (failed(hasStaticInformation))
182 return SmallVector<int64_t>();
183 return strides;
184}
185
186//===----------------------------------------------------------------------===//
187// AllocOp / AllocaOp
188//===----------------------------------------------------------------------===//
189
190void AllocOp::getAsmResultNames(
191 function_ref<void(Value, StringRef)> setNameFn) {
192 setNameFn(getResult(), "alloc");
193}
194
195void AllocaOp::getAsmResultNames(
196 function_ref<void(Value, StringRef)> setNameFn) {
197 setNameFn(getResult(), "alloca");
198}
199
200template <typename AllocLikeOp>
201static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
202 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
203 "applies to only alloc or alloca");
204 auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
205 if (!memRefType)
206 return op.emitOpError("result must be a memref");
207
208 if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
209 memRefType.getNumDynamicDims())
210 return op.emitOpError("dimension operand count does not equal memref "
211 "dynamic dimension count");
212
213 unsigned numSymbols = 0;
214 if (!memRefType.getLayout().isIdentity())
215 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
216 if (op.getSymbolOperands().size() != numSymbols)
217 return op.emitOpError("symbol operand count does not equal memref symbol "
218 "count: expected ")
219 << numSymbols << ", got " << op.getSymbolOperands().size();
220
221 return success();
222}
223
224LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
225
226LogicalResult AllocaOp::verify() {
227 // An alloca op needs to have an ancestor with an allocation scope trait.
228 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
229 return emitOpError(
230 "requires an ancestor op with AutomaticAllocationScope trait");
231
232 return verifyAllocLikeOp(*this);
233}
234
235namespace {
236/// Fold constant dimensions into an alloc like operation.
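///
/// A hypothetical example of the rewrite (sketch only, `%n` assumed to be an
/// index value defined elsewhere):
/// ```mlir
/// %c16 = arith.constant 16 : index
/// %a = memref.alloc(%c16, %n) : memref<?x?xf32>
/// ```
/// is rewritten to:
/// ```mlir
/// %0 = memref.alloc(%n) : memref<16x?xf32>
/// %a = memref.cast %0 : memref<16x?xf32> to memref<?x?xf32>
/// ```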
237template <typename AllocLikeOp>
238struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
239 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
240
241 LogicalResult matchAndRewrite(AllocLikeOp alloc,
242 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
244 // substitute and drop them.
245 if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
246 APInt constSizeArg;
247 if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
248 return false;
249 return constSizeArg.isNonNegative();
250 }))
251 return failure();
252
253 auto memrefType = alloc.getType();
254
255 // Ok, we have one or more constant operands. Collect the non-constant ones
256 // and keep track of the resultant memref type to build.
257 SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
259 SmallVector<Value, 4> dynamicSizes;
260
261 unsigned dynamicDimPos = 0;
262 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
263 int64_t dimSize = memrefType.getDimSize(dim);
264 // If this is already static dimension, keep it.
265 if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old memref.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
279 }
280 dynamicDimPos++;
281 }
282
283 // Create new memref type (which will have fewer dynamic dimensions).
284 MemRefType newMemRefType =
285 MemRefType::Builder(memrefType).setShape(newShapeConstants);
286 assert(static_cast<int64_t>(dynamicSizes.size()) ==
287 newMemRefType.getNumDynamicDims());
288
289 // Create and insert the alloc op for the new memref.
290 auto newAlloc = rewriter.create<AllocLikeOp>(
291 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
292 alloc.getAlignmentAttr());
293 // Insert a cast so we have the same type as the old alloc.
294 rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
295 return success();
296 }
297};
298
299/// Fold alloc operations with no users or only store and dealloc uses.
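///
/// E.g. (illustrative sketch, with `%cst` and `%i` defined elsewhere), the
/// following three ops are erased entirely, since the alloc is only stored
/// into and deallocated:
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %cst, %0[%i] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// ```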
300template <typename T>
301struct SimplifyDeadAlloc : public OpRewritePattern<T> {
302 using OpRewritePattern<T>::OpRewritePattern;
303
304 LogicalResult matchAndRewrite(T alloc,
305 PatternRewriter &rewriter) const override {
306 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
307 if (auto storeOp = dyn_cast<StoreOp>(op))
308 return storeOp.getValue() == alloc;
309 return !isa<DeallocOp>(op);
310 }))
311 return failure();
312
313 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
317 return success();
318 }
319};
320} // namespace
321
322void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
323 MLIRContext *context) {
324 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
325}
326
327void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
328 MLIRContext *context) {
329 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
330 context);
331}
332
333//===----------------------------------------------------------------------===//
334// ReallocOp
335//===----------------------------------------------------------------------===//
336
337LogicalResult ReallocOp::verify() {
338 auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
339 MemRefType resultType = getType();
340
341 // The source memref should have identity layout (or none).
342 if (!sourceType.getLayout().isIdentity())
343 return emitError("unsupported layout for source memref type ")
344 << sourceType;
345
346 // The result memref should have identity layout (or none).
347 if (!resultType.getLayout().isIdentity())
348 return emitError("unsupported layout for result memref type ")
349 << resultType;
350
351 // The source memref and the result memref should be in the same memory space.
352 if (sourceType.getMemorySpace() != resultType.getMemorySpace())
353 return emitError("different memory spaces specified for source memref "
354 "type ")
355 << sourceType << " and result memref type " << resultType;
356
357 // The source memref and the result memref should have the same element type.
358 if (sourceType.getElementType() != resultType.getElementType())
359 return emitError("different element types specified for source memref "
360 "type ")
361 << sourceType << " and result memref type " << resultType;
362
363 // Verify that we have the dynamic dimension operand when it is needed.
364 if (resultType.getNumDynamicDims() && !getDynamicResultSize())
365 return emitError("missing dimension operand for result type ")
366 << resultType;
367 if (!resultType.getNumDynamicDims() && getDynamicResultSize())
368 return emitError("unnecessary dimension operand for result type ")
369 << resultType;
370
371 return success();
372}
373
374void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
375 MLIRContext *context) {
376 results.add<SimplifyDeadAlloc<ReallocOp>>(context);
377}
378
379//===----------------------------------------------------------------------===//
380// AllocaScopeOp
381//===----------------------------------------------------------------------===//
382
383void AllocaScopeOp::print(OpAsmPrinter &p) {
384 bool printBlockTerminators = false;
385
386 p << ' ';
387 if (!getResults().empty()) {
388 p << " -> (" << getResultTypes() << ")";
389 printBlockTerminators = true;
390 }
391 p << ' ';
392 p.printRegion(getBodyRegion(),
393 /*printEntryBlockArgs=*/false,
394 /*printBlockTerminators=*/printBlockTerminators);
395 p.printOptionalAttrDict((*this)->getAttrs());
396}
397
398ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
399 // Create a region for the body.
400 result.regions.reserve(1);
401 Region *bodyRegion = result.addRegion();
402
403 // Parse optional results type list.
404 if (parser.parseOptionalArrowTypeList(result.types))
405 return failure();
406
407 // Parse the body region.
408 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
409 return failure();
410 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
411 result.location);
412
413 // Parse the optional attribute list.
414 if (parser.parseOptionalAttrDict(result.attributes))
415 return failure();
416
417 return success();
418}
419
420void AllocaScopeOp::getSuccessorRegions(
421 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
422 if (!point.isParent()) {
423 regions.push_back(RegionSuccessor(getResults()));
424 return;
425 }
426
427 regions.push_back(RegionSuccessor(&getBodyRegion()));
428}
429
430/// Given an operation, return whether this op is guaranteed to
431/// allocate an AutomaticAllocationScopeResource
432static bool isGuaranteedAutomaticAllocation(Operation *op) {
433 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
434 if (!interface)
435 return false;
436 for (auto res : op->getResults()) {
437 if (auto effect =
438 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
439 if (isa<SideEffects::AutomaticAllocationScopeResource>(
440 effect->getResource()))
441 return true;
442 }
443 }
444 return false;
445}
446
447/// Given an operation, return whether this op itself could
448/// allocate an AutomaticAllocationScopeResource. Note that
449/// this will not check whether an operation contained within
450/// the op can allocate.
451static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
452 // This op itself doesn't create a stack allocation,
453 // the inner allocation should be handled separately.
454 if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
455 return false;
456 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
457 if (!interface)
458 return true;
459 for (auto res : op->getResults()) {
460 if (auto effect =
461 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
462 if (isa<SideEffects::AutomaticAllocationScopeResource>(
463 effect->getResource()))
464 return true;
465 }
466 }
467 return false;
468}
469
470/// Return whether this op is the last non terminating op
471/// in a region. That is to say, it is in a one-block region
472/// and is only followed by a terminator. This prevents
473/// extending the lifetime of allocations.
474static bool lastNonTerminatorInRegion(Operation *op) {
475 return op->getNextNode() == op->getBlock()->getTerminator() &&
476 op->getParentRegion()->getBlocks().size() == 1;
477}
478
479/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
480/// or it contains no allocation.
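///
/// A hand-written sketch of the rewrite when the scope contains no potential
/// allocation (`%a` and `%b` assumed to be index values defined elsewhere):
/// ```mlir
/// %r = memref.alloca_scope -> (index) {
///   %v = arith.addi %a, %b : index
///   memref.alloca_scope.return %v : index
/// }
/// ```
/// becomes:
/// ```mlir
/// %r = arith.addi %a, %b : index
/// ```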
481struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
482 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
483
484 LogicalResult matchAndRewrite(AllocaScopeOp op,
485 PatternRewriter &rewriter) const override {
486 bool hasPotentialAlloca =
487 op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
488 if (alloc == op)
489 return WalkResult::advance();
            if (isOpItselfPotentialAutomaticAllocation(alloc))
491 return WalkResult::interrupt();
492 if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
493 return WalkResult::skip();
494 return WalkResult::advance();
495 }).wasInterrupted();
496
497 // If this contains no potential allocation, it is always legal to
498 // inline. Otherwise, consider two conditions:
499 if (hasPotentialAlloca) {
500 // If the parent isn't an allocation scope, or we are not the last
501 // non-terminator op in the parent, we will extend the lifetime.
502 if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
503 return failure();
504 if (!lastNonTerminatorInRegion(op))
505 return failure();
506 }
507
508 Block *block = &op.getRegion().front();
509 Operation *terminator = block->getTerminator();
510 ValueRange results = terminator->getOperands();
511 rewriter.inlineBlockBefore(block, op);
512 rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
514 return success();
515 }
516};
517
518/// Move allocations into an allocation scope, if it is legal to
519/// move them (e.g. their operands are available at the location
520/// the op would be moved to).
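///
/// A hypothetical before/after (sketch only; `func.func` provides the
/// enclosing automatic allocation scope and `"test.use"` is a placeholder op):
/// ```mlir
/// func.func @f() {
///   scf.execute_region {
///     memref.alloca_scope {
///       %buf = memref.alloca() : memref<16xf32>
///       "test.use"(%buf) : (memref<16xf32>) -> ()
///     }
///     scf.yield
///   }
///   return
/// }
/// ```
/// may become:
/// ```mlir
/// func.func @f() {
///   %buf = memref.alloca() : memref<16xf32>
///   scf.execute_region {
///     memref.alloca_scope {
///       "test.use"(%buf) : (memref<16xf32>) -> ()
///     }
///     scf.yield
///   }
///   return
/// }
/// ```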
521struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
522 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
523
524 LogicalResult matchAndRewrite(AllocaScopeOp op,
525 PatternRewriter &rewriter) const override {
526
527 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
528 return failure();
529
530 Operation *lastParentWithoutScope = op->getParentOp();
531
532 if (!lastParentWithoutScope ||
533 lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
534 return failure();
535
    // Only apply if this is the last non-terminator op in a one-block region
    // (lest the lifetime of the allocations be extended).
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
541 return failure();
542
543 while (!lastParentWithoutScope->getParentOp()
544 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
545 lastParentWithoutScope = lastParentWithoutScope->getParentOp();
546 if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
548 return failure();
549 }
550 assert(lastParentWithoutScope->getParentOp()
551 ->hasTrait<OpTrait::AutomaticAllocationScope>());
552
553 Region *containingRegion = nullptr;
554 for (auto &r : lastParentWithoutScope->getRegions()) {
555 if (r.isAncestor(op->getParentRegion())) {
556 assert(containingRegion == nullptr &&
557 "only one region can contain the op");
558 containingRegion = &r;
559 }
560 }
561 assert(containingRegion && "op must be contained in a region");
562
563 SmallVector<Operation *> toHoist;
564 op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
566 return WalkResult::skip();
567
568 // If any operand is not defined before the location of
569 // lastParentWithoutScope (i.e. where we would hoist to), skip.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
575 return WalkResult::advance();
576 });
577
578 if (toHoist.empty())
579 return failure();
580 rewriter.setInsertionPoint(lastParentWithoutScope);
581 for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
584 }
585 return success();
586 }
587};
588
589void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
590 MLIRContext *context) {
591 results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
592}
593
594//===----------------------------------------------------------------------===//
595// AssumeAlignmentOp
596//===----------------------------------------------------------------------===//
597
598LogicalResult AssumeAlignmentOp::verify() {
599 if (!llvm::isPowerOf2_32(getAlignment()))
600 return emitOpError("alignment must be power of 2");
601 return success();
602}
603
604//===----------------------------------------------------------------------===//
605// CastOp
606//===----------------------------------------------------------------------===//
607
608void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
609 setNameFn(getResult(), "cast");
610}
611
612/// Determines whether MemRef_CastOp casts to a more dynamic version of the
613/// source memref. This is useful to fold a memref.cast into a consuming op
614/// and implement canonicalization patterns for ops in different dialects that
615/// may consume the results of memref.cast operations. Such foldable memref.cast
616/// operations are typically inserted as `view` and `subview` ops are
617/// canonicalized, to preserve the type compatibility of their uses.
618///
619/// Returns true when all conditions are met:
620/// 1. source and result are ranked memrefs with strided semantics and same
621/// element type and rank.
622/// 2. each of the source's size, offset or stride has more static information
623/// than the corresponding result's size, offset or stride.
624///
625/// Example 1:
626/// ```mlir
627/// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
628/// %2 = consumer %1 ... : memref<?x?xf32> ...
629/// ```
630///
631/// may fold into:
632///
633/// ```mlir
634/// %2 = consumer %0 ... : memref<8x16xf32> ...
635/// ```
636///
637/// Example 2:
638/// ```
639/// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
640/// to memref<?x?xf32>
641/// consumer %1 : memref<?x?xf32> ...
642/// ```
643///
644/// may fold into:
645///
646/// ```
647/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
648/// ```
649bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
650 MemRefType sourceType =
651 llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
652 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
653
654 // Requires ranked MemRefType.
655 if (!sourceType || !resultType)
656 return false;
657
658 // Requires same elemental type.
659 if (sourceType.getElementType() != resultType.getElementType())
660 return false;
661
662 // Requires same rank.
663 if (sourceType.getRank() != resultType.getRank())
664 return false;
665
666 // Only fold casts between strided memref forms.
667 int64_t sourceOffset, resultOffset;
668 SmallVector<int64_t, 4> sourceStrides, resultStrides;
669 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
670 failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
671 return false;
672
673 // If cast is towards more static sizes along any dimension, don't fold.
674 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
675 auto ss = std::get<0>(it), st = std::get<1>(it);
676 if (ss != st)
677 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
678 return false;
679 }
680
681 // If cast is towards more static offset along any dimension, don't fold.
682 if (sourceOffset != resultOffset)
683 if (ShapedType::isDynamic(sourceOffset) &&
684 !ShapedType::isDynamic(resultOffset))
685 return false;
686
687 // If cast is towards more static strides along any dimension, don't fold.
688 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
689 auto ss = std::get<0>(it), st = std::get<1>(it);
690 if (ss != st)
691 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
692 return false;
693 }
694
695 return true;
696}
697
698bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
699 if (inputs.size() != 1 || outputs.size() != 1)
700 return false;
701 Type a = inputs.front(), b = outputs.front();
702 auto aT = llvm::dyn_cast<MemRefType>(a);
703 auto bT = llvm::dyn_cast<MemRefType>(b);
704
705 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
706 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
707
708 if (aT && bT) {
709 if (aT.getElementType() != bT.getElementType())
710 return false;
711 if (aT.getLayout() != bT.getLayout()) {
712 int64_t aOffset, bOffset;
713 SmallVector<int64_t, 4> aStrides, bStrides;
714 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
715 failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
716 aStrides.size() != bStrides.size())
717 return false;
718
719 // Strides along a dimension/offset are compatible if the value in the
720 // source memref is static and the value in the target memref is the
721 // same. They are also compatible if either one is dynamic (see
722 // description of MemRefCastOp for details).
723 auto checkCompatible = [](int64_t a, int64_t b) {
724 return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
725 };
726 if (!checkCompatible(aOffset, bOffset))
727 return false;
728 for (const auto &aStride : enumerate(aStrides))
729 if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
730 return false;
731 }
732 if (aT.getMemorySpace() != bT.getMemorySpace())
733 return false;
734
735 // They must have the same rank, and any specified dimensions must match.
736 if (aT.getRank() != bT.getRank())
737 return false;
738
739 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
740 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
741 if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
742 aDim != bDim)
743 return false;
744 }
745 return true;
746 } else {
747 if (!aT && !uaT)
748 return false;
749 if (!bT && !ubT)
750 return false;
751 // Unranked to unranked casting is unsupported
752 if (uaT && ubT)
753 return false;
754
755 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
756 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
757 if (aEltType != bEltType)
758 return false;
759
760 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
761 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
762 return aMemSpace == bMemSpace;
763 }
764
765 return false;
766}
767
768OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
769 return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
770}
771
772//===----------------------------------------------------------------------===//
773// CopyOp
774//===----------------------------------------------------------------------===//
775
776namespace {
777/// If the source/target of a CopyOp is a CastOp that does not modify the shape
778/// and element type, the cast can be skipped. Such CastOps only cast the layout
779/// of the type.
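///
/// Illustrative sketch (the cast only changes the layout, `%src`/`%dst`
/// assumed to be defined elsewhere):
/// ```mlir
/// %0 = memref.cast %src : memref<4xf32> to memref<4xf32, strided<[1], offset: ?>>
/// memref.copy %0, %dst : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32>
/// ```
/// can be rewritten to:
/// ```mlir
/// memref.copy %src, %dst : memref<4xf32> to memref<4xf32>
/// ```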
780struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
781 using OpRewritePattern<CopyOp>::OpRewritePattern;
782
783 LogicalResult matchAndRewrite(CopyOp copyOp,
784 PatternRewriter &rewriter) const override {
785 bool modified = false;
786
787 // Check source.
788 if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
791
792 if (fromType && toType) {
793 if (fromType.getShape() == toType.getShape() &&
794 fromType.getElementType() == toType.getElementType()) {
795 rewriter.modifyOpInPlace(copyOp, [&] {
796 copyOp.getSourceMutable().assign(castOp.getSource());
797 });
798 modified = true;
799 }
800 }
801 }
802
803 // Check target.
804 if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
807
808 if (fromType && toType) {
809 if (fromType.getShape() == toType.getShape() &&
810 fromType.getElementType() == toType.getElementType()) {
811 rewriter.modifyOpInPlace(copyOp, [&] {
812 copyOp.getTargetMutable().assign(castOp.getSource());
813 });
814 modified = true;
815 }
816 }
817 }
818
    return success(modified);
820 }
821};
822
823/// Fold memref.copy(%x, %x).
824struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
825 using OpRewritePattern<CopyOp>::OpRewritePattern;
826
827 LogicalResult matchAndRewrite(CopyOp copyOp,
828 PatternRewriter &rewriter) const override {
829 if (copyOp.getSource() != copyOp.getTarget())
830 return failure();
831
    rewriter.eraseOp(copyOp);
833 return success();
834 }
835};
836} // namespace
837
838void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
839 MLIRContext *context) {
840 results.add<FoldCopyOfCast, FoldSelfCopy>(context);
841}
842
843LogicalResult CopyOp::fold(FoldAdaptor adaptor,
844 SmallVectorImpl<OpFoldResult> &results) {
845 /// copy(memrefcast) -> copy
846 bool folded = false;
847 Operation *op = *this;
848 for (OpOperand &operand : op->getOpOperands()) {
849 auto castOp = operand.get().getDefiningOp<memref::CastOp>();
850 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
851 operand.set(castOp.getOperand());
852 folded = true;
853 }
854 }
855 return success(folded);
856}
857
858//===----------------------------------------------------------------------===//
859// DeallocOp
860//===----------------------------------------------------------------------===//
861
862LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
863 SmallVectorImpl<OpFoldResult> &results) {
864 /// dealloc(memrefcast) -> dealloc
865 return foldMemRefCast(*this);
866}
867
868//===----------------------------------------------------------------------===//
869// DimOp
870//===----------------------------------------------------------------------===//
871
872void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
873 setNameFn(getResult(), "dim");
874}
875
876void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
877 int64_t index) {
878 auto loc = result.location;
879 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
880 build(builder, result, source, indexValue);
881}
882
883std::optional<int64_t> DimOp::getConstantIndex() {
884 return getConstantIntValue(getIndex());
885}
886
887Speculation::Speculatability DimOp::getSpeculatability() {
888 auto constantIndex = getConstantIndex();
889 if (!constantIndex)
890 return Speculation::NotSpeculatable;
891
892 auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
893 if (!rankedSourceType)
894 return Speculation::NotSpeculatable;
895
896 if (rankedSourceType.getRank() <= constantIndex)
897 return Speculation::NotSpeculatable;
898
899 return Speculation::Speculatable;
900}
901
/// Return a map whose keys are the elements of `vals` and whose values are
/// the number of occurrences of each element. Use std::map, since the `vals`
/// here are strides and the dynamic stride value is the same as the tombstone
/// value for `DenseMap<int64_t>`.
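/// E.g. (illustrative): `vals = [4, 1, 1]` yields `{1: 2, 4: 1}`.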
906static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
907 std::map<int64_t, unsigned> numOccurences;
908 for (auto val : vals)
909 numOccurences[val]++;
910 return numOccurences;
911}
912
913/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
914/// to be a subset of `originalType` with some `1` entries erased, return the
915/// set of indices that specifies which of the entries of `originalShape` are
916/// dropped to obtain `reducedShape`.
917/// This accounts for cases where there are multiple unit-dims, but only a
918/// subset of those are dropped. For MemRefTypes these can be disambiguated
919/// using the strides. If a dimension is dropped the stride must be dropped too.
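///
/// For instance (hand-constructed example), with
///   `originalType = memref<1x5x1xf32, strided<[5, 1, 1]>>`,
///   `reducedType  = memref<1x5xf32, strided<[5, 1]>>`, and
///   `sizes        = [1, 5, 1]`,
/// only the trailing unit dimension is reported as dropped: its stride `1`
/// occurs one time less in the reduced type, whereas the leading unit
/// dimension keeps its stride `5`. The returned mask therefore has only bit 2
/// set.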
920static FailureOr<llvm::SmallBitVector>
921computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
922 ArrayRef<OpFoldResult> sizes) {
923 llvm::SmallBitVector unusedDims(originalType.getRank());
924 if (originalType.getRank() == reducedType.getRank())
925 return unusedDims;
926
  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());
931
932 // Early exit for the case where the number of unused dims matches the number
933 // of ranks reduced.
934 if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
935 originalType.getRank())
936 return unusedDims;
937
938 SmallVector<int64_t> originalStrides, candidateStrides;
939 int64_t originalOffset, candidateOffset;
940 if (failed(
941 getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
942 failed(
943 getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
944 return failure();
945
  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim that stride should not
  // be present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to exactly figure out
  // which dim corresponds to which stride, we just need to verify that the
  // number of repetitions of a stride in the original + number of unused dims
  // with that stride == number of repetitions of a stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
959 for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
961 continue;
962 int64_t originalStride = originalStrides[dim];
963 if (currUnaccountedStrides[originalStride] >
964 candidateStridesNumOccurences[originalStride]) {
965 // This dim can be treated as dropped.
966 currUnaccountedStrides[originalStride]--;
967 continue;
968 }
969 if (currUnaccountedStrides[originalStride] ==
970 candidateStridesNumOccurences[originalStride]) {
971 // The stride for this is not dropped. Keep as is.
      unusedDims.reset(dim);
973 continue;
974 }
975 if (currUnaccountedStrides[originalStride] <
976 candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced rank type
      // that wasn't in the original one.
979 return failure();
980 }
981 }
982
983 if ((int64_t)unusedDims.count() + reducedType.getRank() !=
984 originalType.getRank())
985 return failure();
986 return unusedDims;
987}
988
989llvm::SmallBitVector SubViewOp::getDroppedDims() {
990 MemRefType sourceType = getSourceType();
991 MemRefType resultType = getType();
992 FailureOr<llvm::SmallBitVector> unusedDims =
993 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
994 assert(succeeded(unusedDims) && "unable to find unused dims of subview");
995 return *unusedDims;
996}
997
998OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
999 // All forms of folding require a known index.
1000 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1001 if (!index)
1002 return {};
1003
1004 // Folding for unranked types (UnrankedMemRefType) is not supported.
1005 auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1006 if (!memrefType)
1007 return {};
1008
1009 // Out of bound indices produce undefined behavior but are still valid IR.
1010 // Don't choke on them.
1011 int64_t indexVal = index.getInt();
1012 if (indexVal < 0 || indexVal >= memrefType.getRank())
1013 return {};
1014
1015 // Fold if the shape extent along the given index is known.
1016 if (!memrefType.isDynamicDim(index.getInt())) {
1017 Builder builder(getContext());
1018 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1019 }
1020
1021 // The size at the given index is now known to be a dynamic size.
1022 unsigned unsignedIndex = index.getValue().getZExtValue();
1023
1024 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1025 Operation *definingOp = getSource().getDefiningOp();
1026
1027 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1028 return *(alloc.getDynamicSizes().begin() +
1029 memrefType.getDynamicDimIndex(unsignedIndex));
1030
1031 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1032 return *(alloca.getDynamicSizes().begin() +
1033 memrefType.getDynamicDimIndex(unsignedIndex));
1034
1035 if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1036 return *(view.getDynamicSizes().begin() +
1037 memrefType.getDynamicDimIndex(unsignedIndex));
1038
1039 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1040 llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1041 unsigned resultIndex = 0;
1042 unsigned sourceRank = subview.getSourceType().getRank();
1043 unsigned sourceIndex = 0;
1044 for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1045 if (unusedDims.test(i))
1046 continue;
1047 if (resultIndex == unsignedIndex) {
1048 sourceIndex = i;
1049 break;
1050 }
1051 resultIndex++;
1052 }
1053 assert(subview.isDynamicSize(sourceIndex) &&
1054 "expected dynamic subview size");
1055 return subview.getDynamicSize(sourceIndex);
1056 }
1057
1058 if (auto sizeInterface =
1059 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1060 assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1061 "Expected dynamic subview size");
1062 return sizeInterface.getDynamicSize(unsignedIndex);
1063 }
1064
1065 // dim(memrefcast) -> dim
1066 if (succeeded(foldMemRefCast(*this)))
1067 return getResult();
1068
1069 return {};
1070}
1071
1072namespace {
1073/// Fold dim of a memref reshape operation to a load into the reshape's shape
1074/// operand.
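///
/// Illustrative sketch (`%src`, `%shape` and `%c1` assumed to be defined
/// elsewhere, with `%c1` dominating the reshape):
/// ```mlir
/// %reshaped = memref.reshape %src(%shape)
///     : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %reshaped, %c1 : memref<?x?xf32>
/// ```
/// folds to a load from the shape operand:
/// ```mlir
/// %d = memref.load %shape[%c1] : memref<2xindex>
/// ```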
1075struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1076 using OpRewritePattern<DimOp>::OpRewritePattern;
1077
1078 LogicalResult matchAndRewrite(DimOp dim,
1079 PatternRewriter &rewriter) const override {
1080 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1081
1082 if (!reshape)
1083 return rewriter.notifyMatchFailure(
1084 dim, "Dim op is not defined by a reshape op.");
1085
1086 // dim of a memref reshape can be folded if dim.getIndex() dominates the
1087 // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1088 // cheaply check that either of the following conditions hold:
1089 // 1. dim.getIndex() is defined in the same block as reshape but before
1090 // reshape.
1091 // 2. dim.getIndex() is defined in a parent block of
1092 // reshape.
1093
1094 // Check condition 1
1095 if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1096 if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1097 if (reshape->isBeforeInBlock(definingOp)) {
1098 return rewriter.notifyMatchFailure(
1099 dim,
1100 "dim.getIndex is not defined before reshape in the same block.");
1101 }
1102 } // else dim.getIndex is a block argument to reshape->getBlock and
1103 // dominates reshape
1104 } // Check condition 2
1105 else if (dim->getBlock() != reshape->getBlock() &&
1106 !dim.getIndex().getParentRegion()->isProperAncestor(
1107 reshape->getParentRegion())) {
1108 // If dim and reshape are in the same block but dim.getIndex() isn't, we
1109 // already know dim.getIndex() dominates reshape without calling
1110 // `isProperAncestor`
1111 return rewriter.notifyMatchFailure(
1112 dim, "dim.getIndex does not dominate reshape.");
1113 }
1114
1115 // Place the load directly after the reshape to ensure that the shape memref
1116 // was not mutated.
1117 rewriter.setInsertionPointAfter(reshape);
1118 Location loc = dim.getLoc();
1119 Value load =
1120 rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1121 if (load.getType() != dim.getType())
1122 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1123 rewriter.replaceOp(dim, load);
1124 return success();
1125 }
1126};
1127
1128} // namespace
1129
1130void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1131 MLIRContext *context) {
1132 results.add<DimOfMemRefReshape>(context);
1133}
1134
1135// ---------------------------------------------------------------------------
1136// DmaStartOp
1137// ---------------------------------------------------------------------------
1138
1139void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1140 Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1141 ValueRange destIndices, Value numElements,
1142 Value tagMemRef, ValueRange tagIndices, Value stride,
1143 Value elementsPerStride) {
1144 result.addOperands(srcMemRef);
1145 result.addOperands(srcIndices);
1146 result.addOperands(destMemRef);
1147 result.addOperands(destIndices);
1148 result.addOperands({numElements, tagMemRef});
1149 result.addOperands(tagIndices);
1150 if (stride)
1151 result.addOperands({stride, elementsPerStride});
1152}
1153
1154void DmaStartOp::print(OpAsmPrinter &p) {
1155 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1156 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1157 << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1158 if (isStrided())
1159 p << ", " << getStride() << ", " << getNumElementsPerStride();
1160
1161 p.printOptionalAttrDict((*this)->getAttrs());
1162 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1163 << ", " << getTagMemRef().getType();
1164}
1165
1166// Parse DmaStartOp.
1167// Ex:
1168// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1169// %tag[%index], %stride, %num_elt_per_stride :
1170// : memref<3076 x f32, 0>,
1171// memref<1024 x f32, 2>,
1172// memref<1 x i32>
1173//
1174ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1175 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1176 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1177 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1178 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1179 OpAsmParser::UnresolvedOperand numElementsInfo;
1180 OpAsmParser::UnresolvedOperand tagMemrefInfo;
1181 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1182 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1183
1184 SmallVector<Type, 3> types;
1185 auto indexType = parser.getBuilder().getIndexType();
1186
1187 // Parse and resolve the following list of operands:
1188 // *) source memref followed by its indices (in square brackets).
1189 // *) destination memref followed by its indices (in square brackets).
1190 // *) dma size in KiB.
1191 if (parser.parseOperand(srcMemRefInfo) ||
1192 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1193 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1194 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1195 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1196 parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1197 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1198 return failure();
1199
1200 // Parse optional stride and elements per stride.
1201 if (parser.parseTrailingOperandList(strideInfo))
1202 return failure();
1203
1204 bool isStrided = strideInfo.size() == 2;
1205 if (!strideInfo.empty() && !isStrided) {
1206 return parser.emitError(parser.getNameLoc(),
1207 "expected two stride related operands");
1208 }
1209
1210 if (parser.parseColonTypeList(types))
1211 return failure();
1212 if (types.size() != 3)
1213 return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1214
1215 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1216 parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1217 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1218 parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1219 // size should be an index.
1220 parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1221 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1222 // tag indices should be index.
1223 parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1224 return failure();
1225
1226 if (isStrided) {
1227 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1228 return failure();
1229 }
1230
1231 return success();
1232}
1233
1234LogicalResult DmaStartOp::verify() {
1235 unsigned numOperands = getNumOperands();
1236
1237 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1238 // the number of elements.
1239 if (numOperands < 4)
1240 return emitOpError("expected at least 4 operands");
1241
1242 // Check types of operands. The order of these calls is important: the later
1243 // calls rely on some type properties to compute the operand position.
1244 // 1. Source memref.
1245 if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1246 return emitOpError("expected source to be of memref type");
1247 if (numOperands < getSrcMemRefRank() + 4)
1248 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1249 << " operands";
1250 if (!getSrcIndices().empty() &&
1251 !llvm::all_of(getSrcIndices().getTypes(),
1252 [](Type t) { return t.isIndex(); }))
1253 return emitOpError("expected source indices to be of index type");
1254
1255 // 2. Destination memref.
1256 if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1257 return emitOpError("expected destination to be of memref type");
1258 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1259 if (numOperands < numExpectedOperands)
1260 return emitOpError() << "expected at least " << numExpectedOperands
1261 << " operands";
1262 if (!getDstIndices().empty() &&
1263 !llvm::all_of(getDstIndices().getTypes(),
1264 [](Type t) { return t.isIndex(); }))
1265 return emitOpError("expected destination indices to be of index type");
1266
1267 // 3. Number of elements.
1268 if (!getNumElements().getType().isIndex())
1269 return emitOpError("expected num elements to be of index type");
1270
1271 // 4. Tag memref.
1272 if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1273 return emitOpError("expected tag to be of memref type");
1274 numExpectedOperands += getTagMemRefRank();
1275 if (numOperands < numExpectedOperands)
1276 return emitOpError() << "expected at least " << numExpectedOperands
1277 << " operands";
1278 if (!getTagIndices().empty() &&
1279 !llvm::all_of(getTagIndices().getTypes(),
1280 [](Type t) { return t.isIndex(); }))
1281 return emitOpError("expected tag indices to be of index type");
1282
1283 // Optional stride-related operands must be either both present or both
1284 // absent.
1285 if (numOperands != numExpectedOperands &&
1286 numOperands != numExpectedOperands + 2)
1287 return emitOpError("incorrect number of operands");
1288
1289 // 5. Strides.
1290 if (isStrided()) {
1291 if (!getStride().getType().isIndex() ||
1292 !getNumElementsPerStride().getType().isIndex())
1293 return emitOpError(
1294 "expected stride and num elements per stride to be of type index");
1295 }
1296
1297 return success();
1298}
1299
1300LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1301 SmallVectorImpl<OpFoldResult> &results) {
1302 /// dma_start(memrefcast) -> dma_start
1303 return foldMemRefCast(*this);
1304}
1305
1306// ---------------------------------------------------------------------------
1307// DmaWaitOp
1308// ---------------------------------------------------------------------------
1309
1310LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1311 SmallVectorImpl<OpFoldResult> &results) {
1312 /// dma_wait(memrefcast) -> dma_wait
1313 return foldMemRefCast(*this);
1314}
1315
1316LogicalResult DmaWaitOp::verify() {
1317 // Check that the number of tag indices matches the tagMemRef rank.
1318 unsigned numTagIndices = getTagIndices().size();
1319 unsigned tagMemRefRank = getTagMemRefRank();
1320 if (numTagIndices != tagMemRefRank)
1321 return emitOpError() << "expected tagIndices to have the same number of "
1322 "elements as the tagMemRef rank, expected "
1323 << tagMemRefRank << ", but got " << numTagIndices;
1324 return success();
1325}
1326
1327//===----------------------------------------------------------------------===//
1328// ExtractAlignedPointerAsIndexOp
1329//===----------------------------------------------------------------------===//
1330
1331void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1332 function_ref<void(Value, StringRef)> setNameFn) {
1333 setNameFn(getResult(), "intptr");
1334}
1335
1336//===----------------------------------------------------------------------===//
1337// ExtractStridedMetadataOp
1338//===----------------------------------------------------------------------===//
1339
1340/// The number and type of the results are inferred from the
1341/// shape of the source.
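///
/// E.g. (illustrative), a source of type
/// `memref<?x42xf32, strided<[42, 1], offset: ?>>` yields the result types
/// `memref<f32>` (base buffer), `index` (offset), two `index` results for the
/// sizes, and two `index` results for the strides.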
1342LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1343 MLIRContext *context, std::optional<Location> location,
1344 ExtractStridedMetadataOp::Adaptor adaptor,
1345 SmallVectorImpl<Type> &inferredReturnTypes) {
1346 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1347 if (!sourceType)
1348 return failure();
1349
1350 unsigned sourceRank = sourceType.getRank();
1351 IndexType indexType = IndexType::get(context);
1352 auto memrefType =
1353 MemRefType::get({}, sourceType.getElementType(),
1354 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1355 // Base.
1356 inferredReturnTypes.push_back(memrefType);
1357 // Offset.
1358 inferredReturnTypes.push_back(indexType);
1359 // Sizes and strides.
1360 for (unsigned i = 0; i < sourceRank * 2; ++i)
1361 inferredReturnTypes.push_back(indexType);
1362 return success();
1363}
1364
1365void ExtractStridedMetadataOp::getAsmResultNames(
1366 function_ref<void(Value, StringRef)> setNameFn) {
1367 setNameFn(getBaseBuffer(), "base_buffer");
1368 setNameFn(getOffset(), "offset");
1369 // For multi-result to work properly with pretty names and packed syntax `x:3`
1370 // we can only give a pretty name to the first value in the pack.
1371 if (!getSizes().empty()) {
1372 setNameFn(getSizes().front(), "sizes");
1373 setNameFn(getStrides().front(), "strides");
1374 }
1375}
1376
1377/// Helper function to perform the replacement of all constant uses of `values`
1378/// by a materialized constant extracted from `maybeConstants`.
1379/// `values` and `maybeConstants` are expected to have the same size.
1380template <typename Container>
1381static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1382 Container values,
1383 ArrayRef<OpFoldResult> maybeConstants) {
1384 assert(values.size() == maybeConstants.size() &&
1385 " expected values and maybeConstants of the same size");
1386 bool atLeastOneReplacement = false;
1387 for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // Don't materialize a constant if there are no uses: this would introduce
    // infinite loops in the driver.
1390 if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1391 continue;
1392 assert(maybeConstant.template is<Attribute>() &&
1393 "The constified value should be either unchanged (i.e., == result) "
1394 "or a constant");
1395 Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1396 loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
1397 .getInt());
1398 for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1399 // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1400 // yet.
      op->replaceUsesOfWith(result, constantVal);
1402 atLeastOneReplacement = true;
1403 }
1404 }
1405 return atLeastOneReplacement;
1406}
1407
1408LogicalResult
1409ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1410 SmallVectorImpl<OpFoldResult> &results) {
1411 OpBuilder builder(*this);
1412
1413 bool atLeastOneReplacement = replaceConstantUsesOf(
1414 builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1415 getConstifiedMixedOffset());
1416 atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1417 getConstifiedMixedSizes());
1418 atLeastOneReplacement |= replaceConstantUsesOf(
1419 builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1420
1421 return success(atLeastOneReplacement);
1422}
1423
1424SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1425 SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1426 constifyIndexValues(values, getSource().getType(), getContext(),
1427 getConstantSizes, ShapedType::isDynamic);
1428 return values;
1429}
1430
1431SmallVector<OpFoldResult>
1432ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1433 SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1434 constifyIndexValues(values, getSource().getType(), getContext(),
1435 getConstantStrides, ShapedType::isDynamic);
1436 return values;
1437}
1438
1439OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1440 OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1441 SmallVector<OpFoldResult> values(1, offsetOfr);
1442 constifyIndexValues(values, getSource().getType(), getContext(),
1443 getConstantOffset, ShapedType::isDynamic);
1444 return values[0];
1445}
1446
1447//===----------------------------------------------------------------------===//
1448// GenericAtomicRMWOp
1449//===----------------------------------------------------------------------===//
1450
1451void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1452 Value memref, ValueRange ivs) {
1453 OpBuilder::InsertionGuard g(builder);
1454 result.addOperands(memref);
1455 result.addOperands(ivs);
1456
1457 if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1458 Type elementType = memrefType.getElementType();
1459 result.addTypes(elementType);
1460
1461 Region *bodyRegion = result.addRegion();
1462 builder.createBlock(bodyRegion);
1463 bodyRegion->addArgument(elementType, memref.getLoc());
1464 }
1465}
1466
1467LogicalResult GenericAtomicRMWOp::verify() {
1468 auto &body = getRegion();
1469 if (body.getNumArguments() != 1)
    return emitOpError("expected a single entry block argument");
1471
1472 if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument type to match result type");
1474
1475 bool hasSideEffects =
1476 body.walk([&](Operation *nestedOp) {
1477 if (isMemoryEffectFree(nestedOp))
1478 return WalkResult::advance();
1479 nestedOp->emitError(
1480 "body of 'memref.generic_atomic_rmw' should contain "
1481 "only operations with no side effects");
1482 return WalkResult::interrupt();
1483 })
1484 .wasInterrupted();
1485 return hasSideEffects ? failure() : success();
1486}
1487
1488ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1489 OperationState &result) {
1490 OpAsmParser::UnresolvedOperand memref;
1491 Type memrefType;
1492 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1493
1494 Type indexType = parser.getBuilder().getIndexType();
1495 if (parser.parseOperand(memref) ||
1496 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1497 parser.parseColonType(memrefType) ||
1498 parser.resolveOperand(memref, memrefType, result.operands) ||
1499 parser.resolveOperands(ivs, indexType, result.operands))
1500 return failure();
1501
1502 Region *body = result.addRegion();
1503 if (parser.parseRegion(*body, {}) ||
1504 parser.parseOptionalAttrDict(result.attributes))
1505 return failure();
1506 result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1507 return success();
1508}
1509
1510void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1511 p << ' ' << getMemref() << "[" << getIndices()
1512 << "] : " << getMemref().getType() << ' ';
1513 p.printRegion(getRegion());
1514 p.printOptionalAttrDict((*this)->getAttrs());
1515}
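
// For reference, the custom form handled by the parser/printer above looks
// roughly like this (illustrative example):
//
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %cst = arith.constant 1.0 : f32
//     %new = arith.addf %current_value, %cst : f32
//     memref.atomic_yield %new : f32
//   }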
1516
1517//===----------------------------------------------------------------------===//
1518// AtomicYieldOp
1519//===----------------------------------------------------------------------===//
1520
1521LogicalResult AtomicYieldOp::verify() {
1522 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1523 Type resultType = getResult().getType();
1524 if (parentType != resultType)
    return emitOpError() << "type mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
1527 return success();
1528}
1529
1530//===----------------------------------------------------------------------===//
1531// GlobalOp
1532//===----------------------------------------------------------------------===//
1533
1534static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1535 TypeAttr type,
1536 Attribute initialValue) {
1537 p << type;
1538 if (!op.isExternal()) {
1539 p << " = ";
1540 if (op.isUninitialized())
1541 p << "uninitialized";
1542 else
      p.printAttributeWithoutType(initialValue);
1544 }
1545}
1546
1547static ParseResult
1548parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1549 Attribute &initialValue) {
1550 Type type;
  if (parser.parseType(type))
1552 return failure();
1553
1554 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1555 if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
1557 << "type should be static shaped memref, but got " << type;
1558 typeAttr = TypeAttr::get(type);
1559
1560 if (parser.parseOptionalEqual())
1561 return success();
1562
  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1564 initialValue = UnitAttr::get(parser.getContext());
1565 return success();
1566 }
1567
1568 Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
1570 return failure();
  if (!llvm::isa<ElementsAttr>(initialValue))
    return parser.emitError(parser.getNameLoc())
1573 << "initial value should be a unit or elements attribute";
1574 return success();
1575}
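
// For reference, memref.global declarations whose type/initial-value portion
// is handled by the helper above look roughly like (illustrative, approximate
// syntax):
//
//   memref.global constant @c : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @workspace : memref<8xi32> = uninitialized
//
// External globals omit the `=` clause entirely and leave the initial value
// unset.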
1576
1577LogicalResult GlobalOp::verify() {
1578 auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1579 if (!memrefType || !memrefType.hasStaticShape())
1580 return emitOpError("type should be static shaped memref, but got ")
1581 << getType();
1582
1583 // Verify that the initial value, if present, is either a unit attribute or
1584 // an elements attribute.
1585 if (getInitialValue().has_value()) {
1586 Attribute initValue = getInitialValue().value();
1587 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1588 return emitOpError("initial value should be a unit or elements "
1589 "attribute, but got ")
1590 << initValue;
1591
1592 // Check that the type of the initial value is compatible with the type of
1593 // the global variable.
1594 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1595 Type initType = elementsAttr.getType();
1596 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1597 if (initType != tensorType)
1598 return emitOpError("initial value expected to be of type ")
1599 << tensorType << ", but was of type " << initType;
1600 }
1601 }
1602
1603 if (std::optional<uint64_t> alignAttr = getAlignment()) {
1604 uint64_t alignment = *alignAttr;
1605
1606 if (!llvm::isPowerOf2_64(alignment))
1607 return emitError() << "alignment attribute value " << alignment
1608 << " is not a power of 2";
1609 }
1610
1611 // TODO: verify visibility for declarations.
1612 return success();
1613}
1614
1615ElementsAttr GlobalOp::getConstantInitValue() {
1616 auto initVal = getInitialValue();
1617 if (getConstant() && initVal.has_value())
1618 return llvm::cast<ElementsAttr>(initVal.value());
1619 return {};
1620}
1621
1622//===----------------------------------------------------------------------===//
1623// GetGlobalOp
1624//===----------------------------------------------------------------------===//
1625
1626LogicalResult
1627GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1628 // Verify that the result type is same as the type of the referenced
1629 // memref.global op.
1630 auto global =
1631 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1632 if (!global)
1633 return emitOpError("'")
1634 << getName() << "' does not reference a valid global memref";
1635
1636 Type resultType = getResult().getType();
1637 if (global.getType() != resultType)
1638 return emitOpError("result type ")
1639 << resultType << " does not match type " << global.getType()
1640 << " of the global memref @" << getName();
1641 return success();
1642}
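
// For illustration (hypothetical symbol name), the verifier above accepts:
//
//   memref.global constant @c : memref<2xi32> = dense<[1, 4]>
//   ...
//   %0 = memref.get_global @c : memref<2xi32>
//
// and rejects the get_global if its result type differs from the type of @c.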
1643
1644//===----------------------------------------------------------------------===//
1645// LoadOp
1646//===----------------------------------------------------------------------===//
1647
1648LogicalResult LoadOp::verify() {
1649 if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1650 return emitOpError("incorrect number of indices for load, expected ")
1651 << getMemRefType().getRank() << " but got " << getIndices().size();
1652 }
1653 return success();
1654}
1655
1656OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1657 /// load(memrefcast) -> load
1658 if (succeeded(foldMemRefCast(*this)))
1659 return getResult();
1660 return OpFoldResult();
1661}
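
// For illustration, the fold above rewrites
//
//   %c = memref.cast %m : memref<4xf32> to memref<?xf32>
//   %v = memref.load %c[%i] : memref<?xf32>
//
// into a load directly from %m (the cast source), provided the cast source is
// a ranked memref.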
1662
1663//===----------------------------------------------------------------------===//
1664// MemorySpaceCastOp
1665//===----------------------------------------------------------------------===//
1666
1667void MemorySpaceCastOp::getAsmResultNames(
1668 function_ref<void(Value, StringRef)> setNameFn) {
1669 setNameFn(getResult(), "memspacecast");
1670}
1671
1672bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1673 if (inputs.size() != 1 || outputs.size() != 1)
1674 return false;
1675 Type a = inputs.front(), b = outputs.front();
1676 auto aT = llvm::dyn_cast<MemRefType>(a);
1677 auto bT = llvm::dyn_cast<MemRefType>(b);
1678
1679 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1680 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1681
1682 if (aT && bT) {
1683 if (aT.getElementType() != bT.getElementType())
1684 return false;
1685 if (aT.getLayout() != bT.getLayout())
1686 return false;
1687 if (aT.getShape() != bT.getShape())
1688 return false;
1689 return true;
1690 }
1691 if (uaT && ubT) {
1692 return uaT.getElementType() == ubT.getElementType();
1693 }
1694 return false;
1695}
1696
1697OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1698 // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1699 // t2)
1700 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1701 getSourceMutable().assign(parentCast.getSource());
1702 return getResult();
1703 }
1704 return Value{};
1705}
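
// For illustration, with integer memory spaces (values are arbitrary):
//
//   %a = memref.memory_space_cast %m : memref<4xf32, 1> to memref<4xf32, 2>
//   %b = memref.memory_space_cast %a : memref<4xf32, 2> to memref<4xf32>
//
// the fold above re-points %b at %m directly, so the intermediate cast can be
// removed if it becomes unused.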
1706
1707//===----------------------------------------------------------------------===//
1708// PrefetchOp
1709//===----------------------------------------------------------------------===//
1710
1711void PrefetchOp::print(OpAsmPrinter &p) {
1712 p << " " << getMemref() << '[';
1713 p.printOperands(getIndices());
1714 p << ']' << ", " << (getIsWrite() ? "write" : "read");
1715 p << ", locality<" << getLocalityHint();
1716 p << ">, " << (getIsDataCache() ? "data" : "instr");
1717 p.printOptionalAttrDict(
1718 (*this)->getAttrs(),
1719 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1720 p << " : " << getMemRefType();
1721}
1722
1723ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1724 OpAsmParser::UnresolvedOperand memrefInfo;
1725 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1726 IntegerAttr localityHint;
1727 MemRefType type;
1728 StringRef readOrWrite, cacheType;
1729
1730 auto indexTy = parser.getBuilder().getIndexType();
1731 auto i32Type = parser.getBuilder().getIntegerType(32);
1732 if (parser.parseOperand(memrefInfo) ||
1733 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1734 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1735 parser.parseComma() || parser.parseKeyword("locality") ||
1736 parser.parseLess() ||
1737 parser.parseAttribute(localityHint, i32Type, "localityHint",
1738 result.attributes) ||
1739 parser.parseGreater() || parser.parseComma() ||
1740 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1741 parser.resolveOperand(memrefInfo, type, result.operands) ||
1742 parser.resolveOperands(indexInfo, indexTy, result.operands))
1743 return failure();
1744
1745 if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1746 return parser.emitError(parser.getNameLoc(),
1747 "rw specifier has to be 'read' or 'write'");
1748 result.addAttribute(
1749 PrefetchOp::getIsWriteAttrStrName(),
1750 parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1751
1752 if (!cacheType.equals("data") && !cacheType.equals("instr"))
1753 return parser.emitError(parser.getNameLoc(),
1754 "cache type has to be 'data' or 'instr'");
1755
1756 result.addAttribute(
1757 PrefetchOp::getIsDataCacheAttrStrName(),
1758 parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1759
1760 return success();
1761}
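
// For reference, the printed/parsed form handled above looks like
// (illustrative):
//
//   memref.prefetch %buf[%i, %j], read, locality<3>, data : memref<400x400xf32>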
1762
1763LogicalResult PrefetchOp::verify() {
1764 if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("incorrect number of indices");
1766
1767 return success();
1768}
1769
1770LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1771 SmallVectorImpl<OpFoldResult> &results) {
1772 // prefetch(memrefcast) -> prefetch
1773 return foldMemRefCast(*this);
1774}
1775
1776//===----------------------------------------------------------------------===//
1777// RankOp
1778//===----------------------------------------------------------------------===//
1779
1780OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1781 // Constant fold rank when the rank of the operand is known.
1782 auto type = getOperand().getType();
1783 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1784 if (shapedType && shapedType.hasRank())
1785 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1786 return IntegerAttr();
1787}
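
// For illustration, `memref.rank %m : memref<4x?xf32>` folds to the index
// constant 2, whereas the rank of an unranked memref (e.g. memref<*xf32>) is
// left unfolded.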
1788
1789//===----------------------------------------------------------------------===//
1790// ReinterpretCastOp
1791//===----------------------------------------------------------------------===//
1792
1793void ReinterpretCastOp::getAsmResultNames(
1794 function_ref<void(Value, StringRef)> setNameFn) {
1795 setNameFn(getResult(), "reinterpret_cast");
1796}
1797
1798/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1799/// `staticSizes` and `staticStrides` are automatically filled with
1800/// source-memref-rank sentinel values that encode dynamic entries.
1801void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1802 MemRefType resultType, Value source,
1803 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1804 ArrayRef<OpFoldResult> strides,
1805 ArrayRef<NamedAttribute> attrs) {
1806 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1807 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1808 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1809 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1810 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1811 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1812 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1813 b.getDenseI64ArrayAttr(staticSizes),
1814 b.getDenseI64ArrayAttr(staticStrides));
1815 result.addAttributes(attrs);
1816}
1817
1818void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1819 MemRefType resultType, Value source,
1820 int64_t offset, ArrayRef<int64_t> sizes,
1821 ArrayRef<int64_t> strides,
1822 ArrayRef<NamedAttribute> attrs) {
1823 SmallVector<OpFoldResult> sizeValues =
1824 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1825 return b.getI64IntegerAttr(v);
1826 }));
1827 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1828 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1829 return b.getI64IntegerAttr(v);
1830 }));
1831 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1832 strideValues, attrs);
1833}
1834
1835void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1836 MemRefType resultType, Value source, Value offset,
1837 ValueRange sizes, ValueRange strides,
1838 ArrayRef<NamedAttribute> attrs) {
1839 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1840 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1841 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1842 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1843 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1844}
1845
1846// TODO: ponder whether we want to allow missing trailing sizes/strides that are
1847// completed automatically, like we have for subview and extract_slice.
1848LogicalResult ReinterpretCastOp::verify() {
1849 // The source and result memrefs should be in the same memory space.
1850 auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1851 auto resultType = llvm::cast<MemRefType>(getType());
1852 if (srcType.getMemorySpace() != resultType.getMemorySpace())
1853 return emitError("different memory spaces specified for source type ")
1854 << srcType << " and result memref type " << resultType;
1855 if (srcType.getElementType() != resultType.getElementType())
1856 return emitError("different element types specified for source type ")
1857 << srcType << " and result memref type " << resultType;
1858
1859 // Match sizes in result memref type and in static_sizes attribute.
1860 for (auto [idx, resultSize, expectedSize] :
1861 llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1862 if (!ShapedType::isDynamic(resultSize) &&
1863 !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1864 return emitError("expected result type with size = ")
1865 << expectedSize << " instead of " << resultSize
1866 << " in dim = " << idx;
1867 }
1868
  // Match offset and strides in static_offsets and static_strides attributes.
  // If the result memref type has no affine map specified, this will assume an
  // identity layout.
1872 int64_t resultOffset;
1873 SmallVector<int64_t, 4> resultStrides;
1874 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1875 return emitError("expected result type to have strided layout but found ")
1876 << resultType;
1877
1878 // Match offset in result memref type and in static_offsets attribute.
1879 int64_t expectedOffset = getStaticOffsets().front();
1880 if (!ShapedType::isDynamic(resultOffset) &&
1881 !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
1882 return emitError("expected result type with offset = ")
1883 << expectedOffset << " instead of " << resultOffset;
1884
1885 // Match strides in result memref type and in static_strides attribute.
1886 for (auto [idx, resultStride, expectedStride] :
1887 llvm::enumerate(resultStrides, getStaticStrides())) {
1888 if (!ShapedType::isDynamic(resultStride) &&
1889 !ShapedType::isDynamic(expectedStride) &&
1890 resultStride != expectedStride)
1891 return emitError("expected result type with stride = ")
1892 << expectedStride << " instead of " << resultStride
1893 << " in dim = " << idx;
1894 }
1895
1896 return success();
1897}
1898
1899OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1900 Value src = getSource();
1901 auto getPrevSrc = [&]() -> Value {
1902 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1903 if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1904 return prev.getSource();
1905
1906 // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1907 if (auto prev = src.getDefiningOp<CastOp>())
1908 return prev.getSource();
1909
1910 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1911 // are 0.
1912 if (auto prev = src.getDefiningOp<SubViewOp>())
1913 if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1914 return isConstantIntValue(val, 0);
1915 }))
1916 return prev.getSource();
1917
1918 return nullptr;
1919 };
1920
1921 if (auto prevSrc = getPrevSrc()) {
1922 getSourceMutable().assign(prevSrc);
1923 return getResult();
1924 }
1925
1926 // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1927 if (!ShapedType::isDynamicShape(getType().getShape()) &&
1928 src.getType() == getType() && getStaticOffsets().front() == 0) {
1929 return src;
1930 }
1931
1932 return nullptr;
1933}
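
// For illustration, a reinterpret_cast that changes nothing, e.g.
//
//   %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [4, 4],
//       strides: [4, 1] : memref<4x4xf32> to memref<4x4xf32>
//
// folds to %0 directly, since the result type equals the source type, the
// static offset is 0, and the shape is fully static.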
1934
1935SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1936 SmallVector<OpFoldResult> values = getMixedSizes();
1937 constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1938 ShapedType::isDynamic);
1939 return values;
1940}
1941
1942SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1943 SmallVector<OpFoldResult> values = getMixedStrides();
1944 constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1945 ShapedType::isDynamic);
1946 return values;
1947}
1948
1949OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1950 SmallVector<OpFoldResult> values = getMixedOffsets();
1951 assert(values.size() == 1 &&
1952 "reinterpret_cast must have one and only one offset");
1953 constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1954 ShapedType::isDynamic);
1955 return values[0];
1956}
1957
1958namespace {
1959/// Replace the sequence:
1960/// ```
1961/// base, offset, sizes, strides = extract_strided_metadata src
1962/// dst = reinterpret_cast base to offset, sizes, strides
1963/// ```
1964/// With
1965///
1966/// ```
1967/// dst = memref.cast src
1968/// ```
1969///
1970/// Note: The cast operation is only inserted when the type of dst and src
1971/// are not the same. E.g., when going from <4xf32> to <?xf32>.
1972///
1973/// This pattern also matches when the offset, sizes, and strides don't come
1974/// directly from the `extract_strided_metadata`'s results but it can be
1975/// statically proven that they would hold the same values.
1976///
1977/// For instance, the following sequence would be replaced:
1978/// ```
1979/// base, offset, sizes, strides =
1980/// extract_strided_metadata memref : memref<3x4xty>
1981/// dst = reinterpret_cast base to 0, [3, 4], strides
1982/// ```
1983/// Because we know (thanks to the type of the input memref) that variable
1984/// `offset` and `sizes` will respectively hold 0 and [3, 4].
1985///
1986/// Similarly, the following sequence would be replaced:
1987/// ```
1988/// c0 = arith.constant 0
1989/// c4 = arith.constant 4
1990/// base, offset, sizes, strides =
1991/// extract_strided_metadata memref : memref<3x4xty>
1992/// dst = reinterpret_cast base to c0, [3, c4], strides
1993/// ```
1994/// Because we know that `offset`and `c0` will hold 0
1995/// and `c4` will hold 4.
1996struct ReinterpretCastOpExtractStridedMetadataFolder
1997 : public OpRewritePattern<ReinterpretCastOp> {
1998public:
1999 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2000
2001 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2002 PatternRewriter &rewriter) const override {
2003 auto extractStridedMetadata =
2004 op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2005 if (!extractStridedMetadata)
2006 return failure();
2007 // Check if the reinterpret cast reconstructs a memref with the exact same
2008 // properties as the extract strided metadata.
2009
2010 // First, check that the strides are the same.
2011 SmallVector<OpFoldResult> extractStridesOfr =
2012 extractStridedMetadata.getConstifiedMixedStrides();
2013 SmallVector<OpFoldResult> reinterpretStridesOfr =
2014 op.getConstifiedMixedStrides();
2015 if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2016 return failure();
2017
2018 unsigned rank = op.getType().getRank();
2019 for (unsigned i = 0; i < rank; ++i) {
2020 if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2021 return failure();
2022 }
2023
2024 // Second, check the sizes.
2025 assert(extractStridedMetadata.getSizes().size() ==
2026 op.getMixedSizes().size() &&
2027 "Strides and sizes rank must match");
2028 SmallVector<OpFoldResult> extractSizesOfr =
2029 extractStridedMetadata.getConstifiedMixedSizes();
2030 SmallVector<OpFoldResult> reinterpretSizesOfr =
2031 op.getConstifiedMixedSizes();
2032 for (unsigned i = 0; i < rank; ++i) {
2033 if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2034 return failure();
2035 }
2036 // Finally, check the offset.
2037 assert(op.getMixedOffsets().size() == 1 &&
2038 "reinterpret_cast with more than one offset should have been "
2039 "rejected by the verifier");
2040 OpFoldResult extractOffsetOfr =
2041 extractStridedMetadata.getConstifiedMixedOffset();
2042 OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2043 if (extractOffsetOfr != reinterpretOffsetOfr)
2044 return failure();
2045
2046 // At this point, we know that the back and forth between extract strided
2047 // metadata and reinterpret cast is a noop. However, the final type of the
2048 // reinterpret cast may not be exactly the same as the original memref.
2049 // E.g., it could be changing a dimension from static to dynamic. Check that
2050 // here and add a cast if necessary.
2051 Type srcTy = extractStridedMetadata.getSource().getType();
2052 if (srcTy == op.getResult().getType())
2053 rewriter.replaceOp(op, extractStridedMetadata.getSource());
2054 else
2055 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2056 extractStridedMetadata.getSource());
2057
2058 return success();
2059 }
2060};
2061} // namespace
2062
2063void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2064 MLIRContext *context) {
2065 results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2066}
2067
2068//===----------------------------------------------------------------------===//
2069// Reassociative reshape ops
2070//===----------------------------------------------------------------------===//
2071
2072void CollapseShapeOp::getAsmResultNames(
2073 function_ref<void(Value, StringRef)> setNameFn) {
2074 setNameFn(getResult(), "collapse_shape");
2075}
2076
2077void ExpandShapeOp::getAsmResultNames(
2078 function_ref<void(Value, StringRef)> setNameFn) {
2079 setNameFn(getResult(), "expand_shape");
2080}
2081
/// Helper function for verifying the shape of ExpandShapeOp and
/// CollapseShapeOp result and operand. Layout maps are verified separately.
///
/// If `allowMultipleDynamicDimsPerGroup` is set, multiple dynamic dimensions
/// are allowed in a reassociation group.
2087static LogicalResult
2088verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2089 ArrayRef<int64_t> expandedShape,
2090 ArrayRef<ReassociationIndices> reassociation,
2091 bool allowMultipleDynamicDimsPerGroup) {
2092 // There must be one reassociation group per collapsed dimension.
2093 if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
2095 << reassociation.size() << ", expected " << collapsedShape.size();
2096
2097 // The next expected expanded dimension index (while iterating over
2098 // reassociation indices).
2099 int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
2101 ReassociationIndices group = it.value();
2102 int64_t collapsedDim = it.index();
2103
2104 bool foundDynamic = false;
2105 for (int64_t expandedDim : group) {
2106 if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");
2108
2109 if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
2111 << expandedDim << " is out of bounds";
2112
2113 // Check if there are multiple dynamic dims in a reassociation group.
2114 if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2115 if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
2118 foundDynamic = true;
2119 }
2120 }
2121
2122 // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2123 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
2125 << collapsedDim
2126 << ") must be dynamic if and only if reassociation group is "
2127 "dynamic";
2128
2129 // If all dims in the reassociation group are static, the size of the
2130 // collapsed dim can be verified.
2131 if (!foundDynamic) {
2132 int64_t groupSize = 1;
2133 for (int64_t expandedDim : group)
2134 groupSize *= expandedShape[expandedDim];
2135 if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
2137 << collapsedShape[collapsedDim]
2138 << ") must equal reassociation group size (" << groupSize << ")";
2139 }
2140 }
2141
2142 if (collapsedShape.empty()) {
2143 // Rank 0: All expanded dimensions must be 1.
2144 for (int64_t d : expandedShape)
2145 if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
2148 } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2149 // Rank >= 1: Number of dimensions among all reassociation groups must match
2150 // the result memref rank.
    return op->emitOpError("expanded rank (")
2152 << expandedShape.size()
2153 << ") inconsistent with number of reassociation indices (" << nextDim
2154 << ")";
2155 }
2156
2157 return success();
2158}
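
// For example (illustrative), collapsing memref<2x3x4xf32> with reassociation
// [[0, 1], [2]] must produce a 6x4 result: group [0, 1] has static size
// 2 * 3 = 6 and group [2] has size 4. Whether a group may contain more than
// one dynamic dimension is controlled by `allowMultipleDynamicDimsPerGroup`
// (enabled for collapse_shape, disabled for expand_shape below).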
2159
2160SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2161 return getSymbolLessAffineMaps(getReassociationExprs());
2162}
2163
2164SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2165 return convertReassociationIndicesToExprs(getContext(),
2166 getReassociationIndices());
2167}
2168
2169SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2170 return getSymbolLessAffineMaps(getReassociationExprs());
2171}
2172
2173SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2174 return convertReassociationIndicesToExprs(getContext(),
2175 getReassociationIndices());
2176}
2177
2178/// Compute the layout map after expanding a given source MemRef type with the
2179/// specified reassociation indices.
2180static FailureOr<StridedLayoutAttr>
2181computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2182 ArrayRef<ReassociationIndices> reassociation) {
2183 int64_t srcOffset;
2184 SmallVector<int64_t> srcStrides;
2185 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2186 return failure();
2187 assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2188
2189 // 1-1 mapping between srcStrides and reassociation packs.
2190 // Each srcStride starts with the given value and gets expanded according to
2191 // the proper entries in resultShape.
2192 // Example:
2193 // srcStrides = [10000, 1 , 100 ],
2194 // reassociations = [ [0], [1], [2, 3, 4]],
2195 // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2196 // -> For the purpose of stride calculation, the useful sizes are:
2197 // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2198 // resultStrides = [10000, 1, 600, 200, 100]
2199 // Note that a stride does not get expanded along the first entry of each
2200 // shape pack.
2201 SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =
          (SaturatedInteger::wrap(currentStrideToExpand) *
           SaturatedInteger::wrap(resultShape[shapeIndex--]))
              .asInteger();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
2217 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2218}
2219
2220FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2221 MemRefType srcType, ArrayRef<int64_t> resultShape,
2222 ArrayRef<ReassociationIndices> reassociation) {
2223 if (srcType.getLayout().isIdentity()) {
2224 // If the source is contiguous (i.e., no layout map specified), so is the
2225 // result.
2226 MemRefLayoutAttrInterface layout;
2227 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2228 srcType.getMemorySpace());
2229 }
2230
2231 // Source may not be contiguous. Compute the layout map.
2232 FailureOr<StridedLayoutAttr> computedLayout =
2233 computeExpandedLayoutMap(srcType, resultShape, reassociation);
2234 if (failed(computedLayout))
2235 return failure();
2236 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2237 srcType.getMemorySpace());
2238}
2239
2240void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2241 ArrayRef<int64_t> resultShape, Value src,
2242 ArrayRef<ReassociationIndices> reassociation) {
2243 // Only ranked memref source values are supported.
2244 auto srcType = llvm::cast<MemRefType>(src.getType());
2245 FailureOr<MemRefType> resultType =
2246 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2247 // Failure of this assertion usually indicates a problem with the source
2248 // type, e.g., could not get strides/offset.
2249 assert(succeeded(resultType) && "could not compute layout");
2250 build(builder, result, *resultType, src, reassociation);
2251}
2252
2253LogicalResult ExpandShapeOp::verify() {
2254 MemRefType srcType = getSrcType();
2255 MemRefType resultType = getResultType();
2256
2257 if (srcType.getRank() > resultType.getRank()) {
2258 auto r0 = srcType.getRank();
2259 auto r1 = resultType.getRank();
2260 return emitOpError("has source rank ")
2261 << r0 << " and result rank " << r1 << ". This is not an expansion ("
2262 << r0 << " > " << r1 << ").";
2263 }
2264
2265 // Verify result shape.
2266 if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2267 resultType.getShape(),
2268 getReassociationIndices(),
2269 /*allowMultipleDynamicDimsPerGroup=*/false)))
2270 return failure();
2271
2272 // Compute expected result type (including layout map).
2273 FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2274 srcType, resultType.getShape(), getReassociationIndices());
2275 if (failed(expectedResultType))
2276 return emitOpError("invalid source layout map");
2277
2278 // Check actual result type.
2279 if (*expectedResultType != resultType)
2280 return emitOpError("expected expanded type to be ")
2281 << *expectedResultType << " but found " << resultType;
2282
2283 return success();
2284}
2285
2286void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2287 MLIRContext *context) {
2288 results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
2289 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
2290 context);
2291}
2292
2293/// Compute the layout map after collapsing a given source MemRef type with the
2294/// specified reassociation indices.
2295///
2296/// Note: All collapsed dims in a reassociation group must be contiguous. It is
2297/// not possible to check this by inspecting a MemRefType in the general case.
2298/// If non-contiguity cannot be checked statically, the collapse is assumed to
2299/// be valid (and thus accepted by this function) unless `strict = true`.
2300static FailureOr<StridedLayoutAttr>
2301computeCollapsedLayoutMap(MemRefType srcType,
2302 ArrayRef<ReassociationIndices> reassociation,
2303 bool strict = false) {
2304 int64_t srcOffset;
2305 SmallVector<int64_t> srcStrides;
2306 auto srcShape = srcType.getShape();
2307 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2308 return failure();
2309
2310 // The result stride of a reassociation group is the stride of the last entry
2311 // of the reassociation. (TODO: Should be the minimum stride in the
2312 // reassociation because strides are not necessarily sorted. E.g., when using
2313 // memref.transpose.) Dimensions of size 1 should be skipped, because their
2314 // strides are meaningless and could have any arbitrary value.
2315 SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
2317 for (const ReassociationIndices &reassoc : reassociation) {
2318 ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2319 while (srcShape[ref.back()] == 1 && ref.size() > 1)
2320 ref = ref.drop_back();
2321 if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
2323 } else {
2324 // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2325 // the corresponding stride may have to be skipped. (See above comment.)
2326 // Therefore, the result stride cannot be statically determined and must
2327 // be dynamic.
2328 resultStrides.push_back(ShapedType::kDynamic);
2329 }
2330 }
2331
2332 // Validate that each reassociation group is contiguous.
2333 unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2339
      // Both source and result stride must have the same static value. In that
      // case, we can be sure that the dimensions are collapsible (because they
      // are contiguous).
2343 // If `strict = false` (default during op verification), we accept cases
2344 // where one or both strides are dynamic. This is best effort: We reject
2345 // ops where obviously non-contiguous dims are collapsed, but accept ops
2346 // where we cannot be sure statically. Such ops may fail at runtime. See
2347 // the op documentation for details.
      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2349 if (strict && (stride.saturated || srcStride.saturated))
2350 return failure();
2351
2352 if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2353 return failure();
2354 }
2355 }
2356 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2357}
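
// For illustration of the contiguity check above: collapsing
// memref<2x3xf32, strided<[3, 1]>> with reassociation [[0, 1]] is accepted
// (1 * 3 == 3, so the two dims are contiguous) and yields strides [1], whereas
// memref<2x3xf32, strided<[6, 1]>> with the same reassociation is rejected
// because 1 * 3 != 6.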
2358
2359bool CollapseShapeOp::isGuaranteedCollapsible(
2360 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2361 // MemRefs with identity layout are always collapsible.
2362 if (srcType.getLayout().isIdentity())
2363 return true;
2364
2365 return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2366 /*strict=*/true));
2367}
2368
2369MemRefType CollapseShapeOp::computeCollapsedType(
2370 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2371 SmallVector<int64_t> resultShape;
2372 resultShape.reserve(reassociation.size());
2373 for (const ReassociationIndices &group : reassociation) {
2374 auto groupSize = SaturatedInteger::wrap(1);
2375 for (int64_t srcDim : group)
2376 groupSize =
2377 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2378 resultShape.push_back(groupSize.asInteger());
2379 }
2380
2381 if (srcType.getLayout().isIdentity()) {
2382 // If the source is contiguous (i.e., no layout map specified), so is the
2383 // result.
2384 MemRefLayoutAttrInterface layout;
2385 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2386 srcType.getMemorySpace());
2387 }
2388
2389 // Source may not be fully contiguous. Compute the layout map.
2390 // Note: Dimensions that are collapsed into a single dim are assumed to be
2391 // contiguous.
2392 FailureOr<StridedLayoutAttr> computedLayout =
2393 computeCollapsedLayoutMap(srcType, reassociation);
2394 assert(succeeded(computedLayout) &&
2395 "invalid source layout map or collapsing non-contiguous dims");
2396 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2397 srcType.getMemorySpace());
2398}
2399
2400void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2401 ArrayRef<ReassociationIndices> reassociation,
2402 ArrayRef<NamedAttribute> attrs) {
2403 auto srcType = llvm::cast<MemRefType>(src.getType());
2404 MemRefType resultType =
2405 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2406 build(b, result, resultType, src, attrs);
2407 result.addAttribute(::mlir::getReassociationAttrName(),
2408 getReassociationIndicesAttribute(b, reassociation));
2409}
2410
2411LogicalResult CollapseShapeOp::verify() {
2412 MemRefType srcType = getSrcType();
2413 MemRefType resultType = getResultType();
2414
2415 if (srcType.getRank() < resultType.getRank()) {
2416 auto r0 = srcType.getRank();
2417 auto r1 = resultType.getRank();
2418 return emitOpError("has source rank ")
2419 << r0 << " and result rank " << r1 << ". This is not a collapse ("
2420 << r0 << " < " << r1 << ").";
2421 }
2422
2423 // Verify result shape.
2424 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2425 srcType.getShape(), getReassociationIndices(),
2426 /*allowMultipleDynamicDimsPerGroup=*/true)))
2427 return failure();
2428
2429 // Compute expected result type (including layout map).
2430 MemRefType expectedResultType;
2431 if (srcType.getLayout().isIdentity()) {
2432 // If the source is contiguous (i.e., no layout map specified), so is the
2433 // result.
2434 MemRefLayoutAttrInterface layout;
2435 expectedResultType =
2436 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2437 srcType.getMemorySpace());
2438 } else {
2439 // Source may not be fully contiguous. Compute the layout map.
2440 // Note: Dimensions that are collapsed into a single dim are assumed to be
2441 // contiguous.
2442 FailureOr<StridedLayoutAttr> computedLayout =
2443 computeCollapsedLayoutMap(srcType, getReassociationIndices());
2444 if (failed(computedLayout))
2445 return emitOpError(
2446 "invalid source layout map or collapsing non-contiguous dims");
2447 expectedResultType =
2448 MemRefType::get(resultType.getShape(), srcType.getElementType(),
2449 *computedLayout, srcType.getMemorySpace());
2450 }
2451
2452 if (expectedResultType != resultType)
2453 return emitOpError("expected collapsed type to be ")
2454 << expectedResultType << " but found " << resultType;
2455
2456 return success();
2457}
2458
2459struct CollapseShapeOpMemRefCastFolder
2460 : public OpRewritePattern<CollapseShapeOp> {
2461public:
2462 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2463
2464 LogicalResult matchAndRewrite(CollapseShapeOp op,
2465 PatternRewriter &rewriter) const override {
2466 auto cast = op.getOperand().getDefiningOp<CastOp>();
2467 if (!cast)
2468 return failure();
2469
2470 if (!CastOp::canFoldIntoConsumerOp(cast))
2471 return failure();
2472
2473 Type newResultType = CollapseShapeOp::computeCollapsedType(
2474 llvm::cast<MemRefType>(cast.getOperand().getType()),
2475 op.getReassociationIndices());
2476
2477 if (newResultType == op.getResultType()) {
2478 rewriter.modifyOpInPlace(
2479 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2480 } else {
2481 Value newOp = rewriter.create<CollapseShapeOp>(
2482 op->getLoc(), cast.getSource(), op.getReassociationIndices());
2483 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2484 }
2485 return success();
2486 }
2487};
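
// For illustration, the pattern above rewrites
//
//   %0 = memref.cast %arg : memref<2x3x4xf32> to memref<?x?x4xf32>
//   %1 = memref.collapse_shape %0 [[0, 1], [2]]
//       : memref<?x?x4xf32> into memref<?x4xf32>
//
// into a collapse of %arg (typed memref<6x4xf32>) followed by a memref.cast
// back to memref<?x4xf32>.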
2488
2489void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2490 MLIRContext *context) {
2491 results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
2492 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
2493 CollapseShapeOpMemRefCastFolder>(context);
2494}
2495
2496OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2497 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2498 adaptor.getOperands());
2499}
2500
2501OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2502 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2503 adaptor.getOperands());
2504}
2505
2506//===----------------------------------------------------------------------===//
2507// ReshapeOp
2508//===----------------------------------------------------------------------===//
2509
2510void ReshapeOp::getAsmResultNames(
2511 function_ref<void(Value, StringRef)> setNameFn) {
2512 setNameFn(getResult(), "reshape");
2513}
2514
2515LogicalResult ReshapeOp::verify() {
2516 Type operandType = getSource().getType();
2517 Type resultType = getResult().getType();
2518
2519 Type operandElementType =
2520 llvm::cast<ShapedType>(operandType).getElementType();
2521 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2522 if (operandElementType != resultElementType)
2523 return emitOpError("element types of source and destination memref "
2524 "types should be the same");
2525
2526 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2527 if (!operandMemRefType.getLayout().isIdentity())
2528 return emitOpError("source memref type should have identity affine map");
2529
2530 int64_t shapeSize =
2531 llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2532 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2533 if (resultMemRefType) {
2534 if (!resultMemRefType.getLayout().isIdentity())
2535 return emitOpError("result memref type should have identity affine map");
2536 if (shapeSize == ShapedType::kDynamic)
2537 return emitOpError("cannot use shape operand with dynamic length to "
2538 "reshape to statically-ranked memref type");
2539 if (shapeSize != resultMemRefType.getRank())
2540 return emitOpError(
2541 "length of shape operand differs from the result's memref rank");
2542 }
2543 return success();
2544}
2545
2546//===----------------------------------------------------------------------===//
2547// StoreOp
2548//===----------------------------------------------------------------------===//
2549
2550LogicalResult StoreOp::verify() {
2551 if (getNumOperands() != 2 + getMemRefType().getRank())
2552 return emitOpError("store index operand count not equal to memref rank");
2553
2554 return success();
2555}
2556
2557LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2558 SmallVectorImpl<OpFoldResult> &results) {
2559 /// store(memrefcast) -> store
2560 return foldMemRefCast(*this, getValueToStore());
2561}
2562
2563//===----------------------------------------------------------------------===//
2564// SubViewOp
2565//===----------------------------------------------------------------------===//
2566
2567void SubViewOp::getAsmResultNames(
2568 function_ref<void(Value, StringRef)> setNameFn) {
2569 setNameFn(getResult(), "subview");
2570}
2571
2572/// A subview result type can be fully inferred from the source type and the
2573/// static representation of offsets, sizes and strides. Special sentinels
2574/// encode the dynamic case.
2575Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2576 ArrayRef<int64_t> staticOffsets,
2577 ArrayRef<int64_t> staticSizes,
2578 ArrayRef<int64_t> staticStrides) {
2579 unsigned rank = sourceMemRefType.getRank();
2580 (void)rank;
2581 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2582 assert(staticSizes.size() == rank && "staticSizes length mismatch");
2583 assert(staticStrides.size() == rank && "staticStrides length mismatch");
2584
2585 // Extract source offset and strides.
2586 auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
2587
2588 // Compute target offset whose value is:
2589 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2590 int64_t targetOffset = sourceOffset;
2591 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2592 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2593 targetOffset = (SaturatedInteger::wrap(targetOffset) +
2594 SaturatedInteger::wrap(staticOffset) *
2595 SaturatedInteger::wrap(targetStride))
2596 .asInteger();
2597 }
2598
2599 // Compute target stride whose value is:
2600 // `sourceStrides_i * staticStrides_i`.
2601 SmallVector<int64_t, 4> targetStrides;
2602 targetStrides.reserve(staticOffsets.size());
2603 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2604 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2605 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2606 SaturatedInteger::wrap(staticStride))
2607 .asInteger());
2608 }
2609
2610 // The type is now known.
2611 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2612 StridedLayoutAttr::get(sourceMemRefType.getContext(),
2613 targetOffset, targetStrides),
2614 sourceMemRefType.getMemorySpace());
2615}
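
// Worked example for the inference above (illustrative): a subview of
// memref<8x16xf32> (strides [16, 1], offset 0) with offsets [2, 4],
// sizes [4, 4], and strides [1, 2] yields
//   offset = 0 + 2 * 16 + 4 * 1 = 36, strides = [16 * 1, 1 * 2] = [16, 2],
// i.e. memref<4x4xf32, strided<[16, 2], offset: 36>>.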
2616
2617Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2618 ArrayRef<OpFoldResult> offsets,
2619 ArrayRef<OpFoldResult> sizes,
2620 ArrayRef<OpFoldResult> strides) {
2621 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2622 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2623 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2624 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2625 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2626 if (!hasValidSizesOffsets(staticOffsets))
2627 return {};
2628 if (!hasValidSizesOffsets(staticSizes))
2629 return {};
2630 if (!hasValidStrides(staticStrides))
2631 return {};
2632 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2633 staticSizes, staticStrides);
2634}
2635
2636Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2637 MemRefType sourceRankedTensorType,
2638 ArrayRef<int64_t> offsets,
2639 ArrayRef<int64_t> sizes,
2640 ArrayRef<int64_t> strides) {
2641 auto inferredType = llvm::cast<MemRefType>(
2642 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2643 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
         "expected result shape rank to not exceed the inferred rank");
2645 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2646 return inferredType;
2647
2648 // Compute which dimensions are dropped.
2649 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2650 computeRankReductionMask(inferredType.getShape(), resultShape);
2651 assert(dimsToProject.has_value() && "invalid rank reduction");
2652
2653 // Compute the layout and result type.
2654 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2655 SmallVector<int64_t> rankReducedStrides;
2656 rankReducedStrides.reserve(resultShape.size());
2657 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2658 if (!dimsToProject->contains(idx))
2659 rankReducedStrides.push_back(value);
2660 }
2661 return MemRefType::get(resultShape, inferredType.getElementType(),
2662 StridedLayoutAttr::get(inferredLayout.getContext(),
2663 inferredLayout.getOffset(),
2664 rankReducedStrides),
2665 inferredType.getMemorySpace());
2666}
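
// For illustration, rank-reducing the inferred type
// memref<1x4xf32, strided<[16, 2], offset: 36>> to result shape [4] drops the
// unit dimension and its stride, producing
// memref<4xf32, strided<[2], offset: 36>>.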
2667
2668Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2669 MemRefType sourceRankedTensorType,
2670 ArrayRef<OpFoldResult> offsets,
2671 ArrayRef<OpFoldResult> sizes,
2672 ArrayRef<OpFoldResult> strides) {
2673 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2674 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2675 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2676 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2677 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2678 return SubViewOp::inferRankReducedResultType(
2679 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2680 staticStrides);
2681}
2682
2683// Build a SubViewOp with mixed static and dynamic entries and custom result
2684// type. If the type passed is nullptr, it is inferred.
2685void SubViewOp::build(OpBuilder &b, OperationState &result,
2686 MemRefType resultType, Value source,
2687 ArrayRef<OpFoldResult> offsets,
2688 ArrayRef<OpFoldResult> sizes,
2689 ArrayRef<OpFoldResult> strides,
2690 ArrayRef<NamedAttribute> attrs) {
2691 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2692 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2693 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2694 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2695 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2696 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2697 // Structuring implementation this way avoids duplication between builders.
2698 if (!resultType) {
2699 resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2700 sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2701 }
2702 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2703 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2704 b.getDenseI64ArrayAttr(staticSizes),
2705 b.getDenseI64ArrayAttr(staticStrides));
2706 result.addAttributes(attrs);
2707}
2708
2709// Build a SubViewOp with mixed static and dynamic entries and inferred result
2710// type.
2711void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2712 ArrayRef<OpFoldResult> offsets,
2713 ArrayRef<OpFoldResult> sizes,
2714 ArrayRef<OpFoldResult> strides,
2715 ArrayRef<NamedAttribute> attrs) {
2716 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2717}
2718
2719// Build a SubViewOp with static entries and inferred result type.
2720void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2721 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2722 ArrayRef<int64_t> strides,
2723 ArrayRef<NamedAttribute> attrs) {
2724 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2725 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2726 return b.getI64IntegerAttr(v);
2727 }));
2728 SmallVector<OpFoldResult> sizeValues =
2729 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2730 return b.getI64IntegerAttr(v);
2731 }));
2732 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2733 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2734 return b.getI64IntegerAttr(v);
2735 }));
2736 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2737}
2738
2739// Build a SubViewOp with dynamic entries and custom result type. If the
2740// type passed is nullptr, it is inferred.
2741void SubViewOp::build(OpBuilder &b, OperationState &result,
2742 MemRefType resultType, Value source,
2743 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2744 ArrayRef<int64_t> strides,
2745 ArrayRef<NamedAttribute> attrs) {
2746 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2747 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2748 return b.getI64IntegerAttr(v);
2749 }));
2750 SmallVector<OpFoldResult> sizeValues =
2751 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2752 return b.getI64IntegerAttr(v);
2753 }));
2754 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2755 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2756 return b.getI64IntegerAttr(v);
2757 }));
2758 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2759 attrs);
2760}
2761
2762// Build a SubViewOp with dynamic entries and custom result type. If the type
2763// passed is nullptr, it is inferred.
2764void SubViewOp::build(OpBuilder &b, OperationState &result,
2765 MemRefType resultType, Value source, ValueRange offsets,
2766 ValueRange sizes, ValueRange strides,
2767 ArrayRef<NamedAttribute> attrs) {
2768 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2769 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2770 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2771 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2772 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2773 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2774 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2775}
2776
2777// Build a SubViewOp with dynamic entries and inferred result type.
2778void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2779 ValueRange offsets, ValueRange sizes, ValueRange strides,
2780 ArrayRef<NamedAttribute> attrs) {
2781 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2782}
2783
2784/// For ViewLikeOpInterface.
2785Value SubViewOp::getViewSource() { return getSource(); }
2786
2787/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2788/// static value).
2789static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2790 int64_t t1Offset, t2Offset;
2791 SmallVector<int64_t> t1Strides, t2Strides;
2792 auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2793 auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2794 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2795}
2796
2797/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2798/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2799/// marked as dropped in `droppedDims`.
2800static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2801 const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
2803 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2804 "incorrect number of dropped dims");
2805 int64_t t1Offset, t2Offset;
2806 SmallVector<int64_t> t1Strides, t2Strides;
2807 auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2808 auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2809 if (failed(res1) || failed(res2))
2810 return false;
2811 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2812 if (droppedDims[i])
2813 continue;
2814 if (t1Strides[i] != t2Strides[j])
2815 return false;
2816 ++j;
2817 }
2818 return true;
2819}
2820
2821static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2822 Operation *op, Type expectedType) {
2823 auto memrefType = llvm::cast<ShapedType>(expectedType);
2824 switch (result) {
2825 case SliceVerificationResult::Success:
2826 return success();
2827 case SliceVerificationResult::RankTooLarge:
2828 return op->emitError(message: "expected result rank to be smaller or equal to ")
2829 << "the source rank. ";
2830 case SliceVerificationResult::SizeMismatch:
2831 return op->emitError(message: "expected result type to be ")
2832 << expectedType
2833 << " or a rank-reduced version. (mismatch of result sizes) ";
2834 case SliceVerificationResult::ElemTypeMismatch:
2835 return op->emitError(message: "expected result element type to be ")
2836 << memrefType.getElementType();
2837 case SliceVerificationResult::MemSpaceMismatch:
2838 return op->emitError(message: "expected result and source memory spaces to match.");
2839 case SliceVerificationResult::LayoutMismatch:
2840 return op->emitError(message: "expected result type to be ")
2841 << expectedType
2842 << " or a rank-reduced version. (mismatch of result layout) ";
2843 }
2844 llvm_unreachable("unexpected subview verification result");
2845}

/// Verifier for SubViewOp.
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming that there are no rank
  // reductions.
  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));

  // Verify all properties of a shaped type: rank, element type and dimension
  // sizes. This takes into account potential rank reductions.
  auto shapedTypeVerification = isRankReducedType(
      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
  if (shapedTypeVerification != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // The only thing that's left to verify now are the strides. First, compute
  // the unused dimensions due to rank reductions. We have to look at sizes and
  // strides to decide which dimensions were dropped. This function also
  // partially verifies strides in case of rank reductions.
  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
                                                   getMixedSizes());
  if (failed(unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Strides must match.
  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  return success();
}

raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` has a lower
/// rank than `currentSourceType`. Use this signature if `sourceType` is
/// updated together with the result type. In this case, it is important to
/// compute the dropped dimensions using `currentSourceType`, whose strides
/// align with `currentResultType`.
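/// For example (illustrative), if `currentResultType` drops a unit dimension
/// of `currentSourceType`, the same dimension is dropped from the type
/// inferred for `sourceType`, and the remaining shape and strides are taken
/// from that inferred type.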
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides));
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}

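/// Create a rank-reducing subview of `memref` down to `targetShape`, using
/// zero offsets, unit strides, and the full sizes of the source memref.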
Value mlir::memref::createCanonicalRankReducingSubViewOp(
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  auto targetType =
      llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
          targetShape, memrefType, offsets, sizes, strides));
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);
}

FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
                                               Value value,
                                               ArrayRef<int64_t> desiredShape) {
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
}

/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are one, and the source
/// shape is the same as the size of the subview. In such cases, the subview
/// can be folded into its source.
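///
/// Example (illustrative):
/// ```
/// %0 = memref.subview %src[0, 0][8, 16][1, 1] :
///   memref<8x16xf32> to memref<8x16xf32, strided<[16, 1]>>
/// ```
/// is a no-op and can be replaced by `%src` (through a cast if the types
/// differ).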
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    std::optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///   memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V[0, 0][3, 4][1, 1] :
///   memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
/// %1 = memref.cast %0 : memref<3x4xf32, strided<[16, 1], offset: 0>> to
///   memref<3x4xf32, strided<[?, 1], offset: ?>>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let SubViewOpConstantFolder
    // kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops. When the source and result types
/// differ (e.g. due to the use of an `affine_map` layout), a cast is inserted
/// instead of folding to the source directly.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a memref type without taking into account any rank reductions.
    auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
                                            mixedSizes, mixedStrides);
    if (!resTy)
      return {};
    MemRefType nonReducedType = cast<MemRefType>(resTy);

    // Directly return the non-rank reduced type if there are no dropped dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Take the strides and offset from the non-rank reduced type.
    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);

    // Drop dims from shape and strides.
    SmallVector<int64_t> targetShape;
    SmallVector<int64_t> targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
  auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  // Fold subview(subview(x)), where both subviews have the same size and the
  // second subview's offsets are all zero. (I.e., the second subview is a
  // no-op.)
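  // For example (illustrative):
  //   %1 = memref.subview %0[0, 0][%sz0, %sz1][1, 1]
  //     : memref<?x?xf32, strided<[?, 1], offset: ?>>
  //       to memref<?x?xf32, strided<[?, 1], offset: ?>>
  // folds to %0 when %0 is itself a subview with sizes [%sz0, %sz1] and the
  // same result type.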
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(
        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(
        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultShapedType == sourceShapedType)
      return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
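/// For example (illustrative), applying the permutation `(d0, d1) -> (d1, d0)`
/// to `memref<3x4xf32>` yields `memref<4x3xf32, strided<[1, 4]>>`.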
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(
          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType = canonicalizeStridedLayout(
      inferTransposeResultType(srcType, getPermutation()));

  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  return success();
}

OpFoldResult TransposeOp::fold(FoldAdaptor) {
  // First, check for the identity permutation: it can be folded away if the
  // input and result types are already identical.
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();
  // Fold two consecutive memref.transpose Ops into one by composing their
  // permutation maps.
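  // For example (illustrative), two consecutive transposes with the
  // permutation (d0, d1) -> (d1, d0) compose to the identity map; if the
  // types also match, the identity fold above then removes the op entirely.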
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}

LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}

Value ViewOp::getViewSource() { return getSource(); }

namespace {

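/// Fold constant-producing dynamic size operands of a ViewOp into a more
/// static result type, and cast the new view back to the original type.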
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type `memrefType`.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};

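/// Fold away a memref.cast of a memref.alloc that feeds a view: the view
/// reinterprets the underlying buffer, so it can consume the alloc directly.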
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }
};

} // namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

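/// Verifier for AtomicRMWOp. For reference (illustrative), a well-formed op
/// looks like:
///   %x = memref.atomic_rmw addf %value, %buf[%i] : (f32, memref<10xf32>) -> f32
/// i.e. one index per memref dimension and a value whose type matches the kind.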
LogicalResult AtomicRMWOp::verify() {
  if (getMemRefType().getRank() != getNumOperands() - 2)
    return emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  switch (getKind()) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::maximumf:
  case arith::AtomicRMWKind::minimumf:
  case arith::AtomicRMWKind::mulf:
    if (!llvm::isa<FloatType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects a floating-point type";
    break;
  case arith::AtomicRMWKind::addi:
  case arith::AtomicRMWKind::maxs:
  case arith::AtomicRMWKind::maxu:
  case arith::AtomicRMWKind::mins:
  case arith::AtomicRMWKind::minu:
  case arith::AtomicRMWKind::muli:
  case arith::AtomicRMWKind::ori:
  case arith::AtomicRMWKind::andi:
    if (!llvm::isa<IntegerType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}

OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
  /// atomicrmw(memrefcast) -> atomicrmw
  if (succeeded(foldMemRefCast(*this, getValue())))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"

