//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
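/// Converts `shape.any` by replacing it with its first operand; for
/// error-free extent tensors, any operand is a valid result. For example
/// (syntax shown schematically):
///
///   %result = shape.any %a, %b
///       : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
///
/// becomes simply `%a`.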
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {adaptor.getInputs().front()});
  return success();
}

namespace {
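/// Converts binary shape ops such as `shape.add` and `shape.mul` on
/// error-free `index` operands to the corresponding `arith` op. For example
/// (syntax shown schematically):
///
///   %result = shape.add %a, %b : index, index -> index
///
/// becomes
///
///   %result = arith.addi %a, %b : index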
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering.
    if (isa<SizeType>(op.getType()))
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace

namespace {
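/// Converts `shape.broadcast` on extent tensors to a `tensor.generate` of
/// size max(rank(%a), rank(%b), ...) whose body computes each broadcasted
/// extent via `getBroadcastedDim` below. Schematically,
///
///   %result = shape.broadcast %a, %b
///       : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
///
/// becomes a `tensor.generate` that, per output dimension, selects the other
/// operand's extent wherever one operand's extent is 1 or the dimension is
/// absent due to lower rank.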
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                 outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              outOfBounds,
              [&](OpBuilder &b, Location loc) {
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the
                //   extent from the greater-rank operand) is equal to 1,
                //   then take the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
                Value lesserRankOperandDimension = b.create<arith::SubIOp>(
                    loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne =
                    b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            lesserRankOperandExtent, one);
                Value dim = b.create<arith::SelectOp>(
                    loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
                              rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type resultTy =
      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
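/// Converts `shape.const_size` to an `arith.constant` of index type. For
/// example (syntax shown schematically):
///
///   %size = shape.const_size 42
///
/// becomes
///
///   %size = arith.constant 42 : index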
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}

namespace {
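/// Converts `shape.is_broadcastable` on extent tensors to an `scf.for` loop
/// over the broadcasted rank that reduces an i1 conjunction: in every output
/// dimension, each present operand extent must be 1 or equal to the
/// broadcasted extent. Schematically,
///
///   %w = shape.is_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes an `scf.for` from 0 to max(rank(%a), rank(%b)) yielding an i1.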
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  // Note: `lb` already carries the location, so it is not passed again here.
  auto reduceResult = lb.create<ForOp>(
      zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every value needs to be either 1, or the same non-1
                     // value to be broadcastable in this dim.
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value result = b.create<arith::AndIOp>(
                         loc, broadcastable,
                         b.create<arith::OrIOp>(loc, equalOne,
                                                equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}

namespace {
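/// Converts `shape.dim` into `shape.get_extent` of `shape.shape_of` and
/// relies on the further lowering of those ops. For example (syntax shown
/// schematically):
///
///   %extent = shape.dim %arg, %idx : tensor<2x3xf32>, index -> index
///
/// becomes
///
///   %shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<2xindex>
///   %extent = shape.get_extent %shape, %idx
///       : tensor<2xindex>, index -> index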
class DimOpConverter : public OpConversionPattern<DimOp> {
  using OpConversionPattern<DimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  // Lower `dim(X, i)` to `get_extent(shape_of(X), i)` and rely on further
  // lowerings. This can be further optimized if needed to avoid intermediate
  // steps.
  auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                  op.getIndex());
  return success();
}

namespace {
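/// Converts `shape.get_extent` on an extent tensor to `tensor.extract`, or
/// directly to `tensor.dim` when the shape stems from `shape.shape_of` of a
/// ranked tensor. For example (syntax shown schematically):
///
///   %extent = shape.get_extent %shape, %idx
///       : tensor<?xindex>, index -> index
///
/// becomes
///
///   %extent = tensor.extract %shape[%idx] : tensor<?xindex>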
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (isa<SizeType>(op.getType()))
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}

namespace {
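/// Converts `shape.rank` on an extent tensor to `tensor.dim` of its only
/// dimension. For example (syntax shown schematically):
///
///   %rank = shape.rank %shape : tensor<?xindex> -> index
///
/// becomes
///
///   %c0 = arith.constant 0 : index
///   %rank = tensor.dim %shape, %c0 : tensor<?xindex>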
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (isa<SizeType>(op.getType()))
    return failure();

  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
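/// For example (syntax shown schematically):
///
///   %result = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
///     ^bb0(%i : index, %extent : index, %acc : index):
///       %next = shape.mul %acc, %extent : index, index -> index
///       shape.yield %next : index
///   }
///
/// becomes an `scf.for` from 0 to the rank of %shape whose body extracts
/// each extent with `tensor.extract` and clones the reduction block into the
/// loop, yielding the mapped results.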
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (isa<ShapeType>(op.getShape().getType()))
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        IRMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = arith.constant 0 : index
/// %0 = tensor.dim %arg0, %c0 : tensor<?xindex>
/// %1 = tensor.dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi eq, %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = arith.constant 1 : index
///   %true = arith.constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi eq, %5, %6 : index
///     %8 = arith.andi %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = arith.constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<arith::AndIOp>(loc, result,
                                                      same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
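/// Converts `shape.shape_of` to a `tensor.from_elements` of the individual
/// (constant or `tensor.dim`) extents for ranked operands, and to a
/// `tensor.generate` over `tensor.rank` for unranked ones. For example
/// (syntax shown schematically):
///
///   %shape = shape.shape_of %arg : tensor<2x?xf32> -> tensor<2xindex>
///
/// becomes roughly
///
///   %c2 = arith.constant 2 : index
///   %c1 = arith.constant 1 : index
///   %d1 = tensor.dim %arg, %c1 : tensor<2x?xf32>
///   %shape = tensor.from_elements %c2, %d1 : tensor<2xindex>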
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (isa<ShapeType>(op.getType()))
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (isa<RankedTensorType>(tensorTy)) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
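/// Converts `shape.split_at` on extent tensors to two `tensor.extract_slice`
/// ops, after normalizing a negative split index by adding the rank.
/// Schematically, for a split index %i and an operand %shape of rank %n,
/// the results are the slices
///
///   %head = %shape[0 .. %i)
///   %tail = %shape[%i .. %n)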
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value add = b.create<arith::AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);

  Value one = b.create<arith::ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
  Value tailSize = b.create<arith::SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
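/// Converts `shape.to_extent_tensor` with a ranked tensor input to a plain
/// `tensor.cast` to the result type. For example (syntax shown
/// schematically):
///
///   %result = shape.to_extent_tensor %input
///       : tensor<?xindex> -> tensor<3xindex>
///
/// becomes
///
///   %result = tensor.cast %input : tensor<?xindex> to tensor<3xindex>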
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<RankedTensorType>(adaptor.getInput().getType()))
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public impl::ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();

  // Set up conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      DimOpConverter,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}