//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
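/// Converts `shape.any` by forwarding one of its operands. As an illustrative
/// sketch (assuming extent tensor operands; not verbatim pass output):
///
///   %result = shape.any %a, %b : tensor<?xindex>, tensor<?xindex>
///       -> tensor<?xindex>
///
/// is replaced by a direct use of %a.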
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {adaptor.getInputs().front()});
  return success();
}

namespace {
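/// Converts binary shape ops to their arith counterparts, e.g. `shape.add` to
/// `arith.addi` and `shape.mul` to `arith.muli`. As an illustrative sketch on
/// `index`-typed operands (the only case this lowering accepts):
///
///   %sum = shape.add %a, %b : index, index -> index
///
/// becomes roughly:
///
///   %sum = arith.addi %a, %b : index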
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering.
    if (isa<SizeType>(op.getType()))
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace

namespace {
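/// Lowers `shape.broadcast` on extent tensors to a `tensor.generate` whose
/// body computes each output extent via `getBroadcastedDim` below. As an
/// illustrative sketch (not verbatim pass output):
///
///   %rank = tensor.dim %a, %c0 : tensor<?xindex>
///   ... compute the maximum rank and the per-operand rank differences ...
///   %result = tensor.generate %max_rank { ... } : tensor<?xindex>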
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
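// For example (illustrative): broadcasting the shapes [2, 5] and [5] yields
// [2, 5]. At output dimension 0 the rank-1 operand is out of bounds (its rank
// difference is 1), so the running extent is kept; at dimension 1 both extents
// are 5. An extent of 1 defers to the other operand's extent.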
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                 outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              outOfBounds,
              [&](OpBuilder &b, Location loc) {
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the
                //   extent from the greater-rank operand) is equal to 1,
                //   then take the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
                Value lesserRankOperandDimension = b.create<arith::SubIOp>(
                    loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne =
                    b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            lesserRankOperandExtent, one);
                Value dim = b.create<arith::SelectOp>(
                    loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
                              rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
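/// Lowers `shape.const_shape` to a constant extent tensor. As an illustrative
/// sketch (not verbatim pass output):
///
///   %shape = shape.const_shape [2, 3] : tensor<2xindex>
///
/// becomes roughly:
///
///   %c2 = arith.constant 2 : index
///   %c3 = arith.constant 3 : index
///   %shape = tensor.from_elements %c2, %c3 : tensor<2xindex>
///
/// followed by a `tensor.cast` to the op's result type where needed.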
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type resultTy =
      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
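/// Lowers `shape.const_size` to an `index` constant, e.g. `shape.const_size 3`
/// becomes roughly `arith.constant 3 : index`.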
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}

namespace {
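/// Lowers `shape.is_broadcastable` on extent tensors to a loop over the output
/// dimensions that checks, per operand, that every in-bounds extent is either
/// 1 or equal to the broadcasted extent. As an illustrative sketch:
///
///   %ok = scf.for %i = %c0 to %max_rank step %c1
///       iter_args(%all = %true) -> (i1) { ... }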
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  auto reduceResult = lb.create<ForOp>(
      zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every value needs to be either 1, or the same non-1
                     // value to be broadcastable in this dim.
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value result = b.create<arith::AndIOp>(
                         loc, broadcastable,
                         b.create<arith::OrIOp>(loc, equalOne,
                                                equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}

namespace {
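/// Lowers `shape.dim` indirectly: `shape.dim %t, %i` is rewritten to the
/// equivalent of `shape.get_extent(shape.shape_of(%t), %i)` (an illustrative
/// paraphrase), which the other patterns in this file then lower further.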
class DimOpConverter : public OpConversionPattern<DimOp> {
  using OpConversionPattern<DimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  // Lower dim(X, i) to get_extent(shape_of(X), i) and rely on further
  // lowerings. This can be further optimized if needed to avoid intermediate
  // steps.
  auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                  op.getIndex());
  return success();
}

namespace {
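/// Lowers `shape.get_extent` either directly to `tensor.dim` on the shaped
/// value (when the shape was produced by `shape.shape_of`) or to a
/// `tensor.extract` from the extent tensor. As an illustrative sketch:
///
///   %e = shape.get_extent %shape, %i : tensor<?xindex>, index -> index
///
/// becomes roughly:
///
///   %e = tensor.extract %shape[%i] : tensor<?xindex>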
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (isa<SizeType>(op.getType()))
    return failure();

  // Derive the shape extent directly from the shape's origin if possible.
  // This avoids materializing the shape in memory.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}

namespace {
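/// Lowers `shape.rank` on an extent tensor to `tensor.dim %shape, %c0`: the
/// rank is simply the number of extents stored in the tensor.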
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (isa<SizeType>(op.getType()))
    return failure();

  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
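/// The reduction body is cloned into the loop with the induction variable, the
/// current extent, and the iteration arguments substituted for its block
/// arguments. As an illustrative sketch:
///
///   %r = scf.for %i = %c0 to %rank step %c1 iter_args(%acc = %init) -> ... {
///     %extent = tensor.extract %shape[%i] : tensor<?xindex>
///     ... cloned reduction body ...
///   }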
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (isa<ShapeType>(op.getShape().getType()))
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        IRMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the shapes' ranks and, if they are equal, checks every extent for
/// equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = arith.constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = arith.constant 1 : index
///   %true = arith.constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi "eq", %5, %6 : index
///     %8 = arith.andi %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = arith.constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<arith::AndIOp>(loc, result,
                                                      same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
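/// Lowers `shape.shape_of`. For ranked operands the extents are materialized
/// with `tensor.from_elements`; for unranked operands a `tensor.generate` over
/// the operand's rank is emitted instead. As an illustrative sketch for a
/// `tensor<2x?xf32>` operand:
///
///   %c2 = arith.constant 2 : index
///   %d1 = tensor.dim %arg, %c1 : tensor<2x?xf32>
///   %shape = tensor.from_elements %c2, %d1 : tensor<2xindex>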
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (isa<ShapeType>(op.getType()))
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (isa<RankedTensorType>(tensorTy)) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
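/// Lowers `shape.split_at` to two `tensor.extract_slice` ops, one for the head
/// and one for the tail, after normalizing a possibly negative split index. As
/// an illustrative sketch:
///
///   %head = tensor.extract_slice %shape[0] [%index] [1] : ...
///   %tail = tensor.extract_slice %shape[%index] [%tail_size] [1] : ...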
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value add = b.create<arith::AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);

  Value one = b.create<arith::ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
  Value tailSize = b.create<arith::SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<RankedTensorType>(adaptor.getInput().getType()))
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
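/// Assuming the command-line name registered for this pass in Passes.td, it
/// can be exercised through mlir-opt roughly as:
///
///   mlir-opt --convert-shape-to-std input.mlir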
class ConvertShapeToStandardPass
    : public impl::ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();

  // Set up the conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply the conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      DimOpConverter,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}
