1//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/Math/IR/Math.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/Utils/IndexingUtils.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
20#include "mlir/IR/ImplicitLocOpBuilder.h"
21#include "mlir/IR/TypeUtilities.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/TypeSwitch.h"
26#include "llvm/Support/Debug.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34
35#define DEBUG_TYPE "math-to-funcs"
36#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37
38namespace {
39// Pattern to convert vector operations to scalar operations.
40template <typename Op>
41struct VecOpToScalarOp : public OpRewritePattern<Op> {
42public:
43 using OpRewritePattern<Op>::OpRewritePattern;
44
45 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46};
47
48// Callback type for getting pre-generated FuncOp implementing
49// an operation of the given type.
50using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
51
52// Pattern to convert scalar IPowIOp into a call of outlined
53// software implementation.
54class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
55public:
56 IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
57 : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
58
59 /// Convert IPowI into a call to a local function implementing
60 /// the power operation. The local function computes a scalar result,
61 /// so vector forms of IPowI are linearized.
62 LogicalResult matchAndRewrite(math::IPowIOp op,
63 PatternRewriter &rewriter) const final;
64
65private:
66 GetFuncCallbackTy getFuncOpCallback;
67};
68
69// Pattern to convert scalar FPowIOp into a call of outlined
70// software implementation.
71class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
72public:
73 FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
74 : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
75
76 /// Convert FPowI into a call to a local function implementing
77 /// the power operation. The local function computes a scalar result,
78 /// so vector forms of FPowI are linearized.
79 LogicalResult matchAndRewrite(math::FPowIOp op,
80 PatternRewriter &rewriter) const final;
81
82private:
83 GetFuncCallbackTy getFuncOpCallback;
84};
85
86// Pattern to convert scalar ctlz into a call of outlined software
87// implementation.
88class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
89public:
90 CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
91 : OpRewritePattern<math::CountLeadingZerosOp>(context),
92 getFuncOpCallback(cb) {}
93
94 /// Convert ctlz into a call to a local function implementing
95 /// the count leading zeros operation.
96 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
97 PatternRewriter &rewriter) const final;
98
99private:
100 GetFuncCallbackTy getFuncOpCallback;
101};
102} // namespace
103
104template <typename Op>
105LogicalResult
106VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
107 Type opType = op.getType();
108 Location loc = op.getLoc();
109 auto vecType = dyn_cast<VectorType>(opType);
110
111 if (!vecType)
112 return rewriter.notifyMatchFailure(op, "not a vector operation");
113 if (!vecType.hasRank())
114 return rewriter.notifyMatchFailure(op, "unknown vector rank");
115 ArrayRef<int64_t> shape = vecType.getShape();
116 int64_t numElements = vecType.getNumElements();
117
118 Type resultElementType = vecType.getElementType();
119 Attribute initValueAttr;
120 if (isa<FloatType>(resultElementType))
121 initValueAttr = FloatAttr::get(resultElementType, 0.0);
122 else
123 initValueAttr = IntegerAttr::get(resultElementType, 0);
124 Value result = rewriter.create<arith::ConstantOp>(
125 loc, DenseElementsAttr::get(vecType, initValueAttr));
126 SmallVector<int64_t> strides = computeStrides(sizes: shape);
127 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
128 SmallVector<int64_t> positions = delinearize(linearIndex, strides);
129 SmallVector<Value> operands;
130 for (Value input : op->getOperands())
131 operands.push_back(
132 rewriter.create<vector::ExtractOp>(loc, input, positions));
133 Value scalarOp =
134 rewriter.create<Op>(loc, vecType.getElementType(), operands);
135 result =
136 rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
137 }
138 rewriter.replaceOp(op, result);
139 return success();
140}
141
142static FunctionType getElementalFuncTypeForOp(Operation *op) {
143 SmallVector<Type, 1> resultTys(op->getNumResults());
144 SmallVector<Type, 2> inputTys(op->getNumOperands());
145 std::transform(first: op->result_type_begin(), last: op->result_type_end(),
146 result: resultTys.begin(),
147 unary_op: [](Type ty) { return getElementTypeOrSelf(type: ty); });
148 std::transform(first: op->operand_type_begin(), last: op->operand_type_end(),
149 result: inputTys.begin(),
150 unary_op: [](Type ty) { return getElementTypeOrSelf(type: ty); });
151 return FunctionType::get(op->getContext(), inputTys, resultTys);
152}
153
154/// Create linkonce_odr function to implement the power function with
155/// the given \p elementType type inside \p module. The \p elementType
156/// must be IntegerType, an the created function has
157/// 'IntegerType (*)(IntegerType, IntegerType)' function type.
158///
159/// template <typename T>
160/// T __mlir_math_ipowi_*(T b, T p) {
161/// if (p == T(0))
162/// return T(1);
163/// if (p < T(0)) {
164/// if (b == T(0))
165/// return T(1) / T(0); // trigger div-by-zero
166/// if (b == T(1))
167/// return T(1);
168/// if (b == T(-1)) {
169/// if (p & T(1))
170/// return T(-1);
171/// return T(1);
172/// }
173/// return T(0);
174/// }
175/// T result = T(1);
176/// while (true) {
177/// if (p & T(1))
178/// result *= b;
179/// p >>= T(1);
180/// if (p == T(0))
181/// return result;
182/// b *= b;
183/// }
184/// }
185static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
186 assert(isa<IntegerType>(elementType) &&
187 "non-integer element type for IPowIOp");
188
189 ImplicitLocOpBuilder builder =
190 ImplicitLocOpBuilder::atBlockEnd(loc: module->getLoc(), block: module->getBody());
191
192 std::string funcName("__mlir_math_ipowi");
193 llvm::raw_string_ostream nameOS(funcName);
194 nameOS << '_' << elementType;
195
196 FunctionType funcType = FunctionType::get(
197 builder.getContext(), {elementType, elementType}, elementType);
198 auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
199 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
200 Attribute linkage =
201 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
202 funcOp->setAttr("llvm.linkage", linkage);
203 funcOp.setPrivate();
204
205 Block *entryBlock = funcOp.addEntryBlock();
206 Region *funcBody = entryBlock->getParent();
207
208 Value bArg = funcOp.getArgument(0);
209 Value pArg = funcOp.getArgument(1);
210 builder.setInsertionPointToEnd(entryBlock);
211 Value zeroValue = builder.create<arith::ConstantOp>(
212 elementType, builder.getIntegerAttr(elementType, 0));
213 Value oneValue = builder.create<arith::ConstantOp>(
214 elementType, builder.getIntegerAttr(elementType, 1));
215 Value minusOneValue = builder.create<arith::ConstantOp>(
216 elementType,
217 builder.getIntegerAttr(elementType,
218 APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
219 /*isSigned=*/true)));
220
221 // if (p == T(0))
222 // return T(1);
223 auto pIsZero =
224 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
225 Block *thenBlock = builder.createBlock(parent: funcBody);
226 builder.create<func::ReturnOp>(oneValue);
227 Block *fallthroughBlock = builder.createBlock(parent: funcBody);
228 // Set up conditional branch for (p == T(0)).
229 builder.setInsertionPointToEnd(pIsZero->getBlock());
230 builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
231
232 // if (p < T(0)) {
233 builder.setInsertionPointToEnd(fallthroughBlock);
234 auto pIsNeg =
235 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
236 // if (b == T(0))
237 builder.createBlock(parent: funcBody);
238 auto bIsZero =
239 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
240 // return T(1) / T(0);
241 thenBlock = builder.createBlock(parent: funcBody);
242 builder.create<func::ReturnOp>(
243 builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
244 fallthroughBlock = builder.createBlock(parent: funcBody);
245 // Set up conditional branch for (b == T(0)).
246 builder.setInsertionPointToEnd(bIsZero->getBlock());
247 builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
248
249 // if (b == T(1))
250 builder.setInsertionPointToEnd(fallthroughBlock);
251 auto bIsOne =
252 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
253 // return T(1);
254 thenBlock = builder.createBlock(parent: funcBody);
255 builder.create<func::ReturnOp>(oneValue);
256 fallthroughBlock = builder.createBlock(parent: funcBody);
257 // Set up conditional branch for (b == T(1)).
258 builder.setInsertionPointToEnd(bIsOne->getBlock());
259 builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
260
261 // if (b == T(-1)) {
262 builder.setInsertionPointToEnd(fallthroughBlock);
263 auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
264 bArg, minusOneValue);
265 // if (p & T(1))
266 builder.createBlock(parent: funcBody);
267 auto pIsOdd = builder.create<arith::CmpIOp>(
268 arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
269 zeroValue);
270 // return T(-1);
271 thenBlock = builder.createBlock(parent: funcBody);
272 builder.create<func::ReturnOp>(minusOneValue);
273 fallthroughBlock = builder.createBlock(parent: funcBody);
274 // Set up conditional branch for (p & T(1)).
275 builder.setInsertionPointToEnd(pIsOdd->getBlock());
276 builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
277
278 // return T(1);
279 // } // b == T(-1)
280 builder.setInsertionPointToEnd(fallthroughBlock);
281 builder.create<func::ReturnOp>(oneValue);
282 fallthroughBlock = builder.createBlock(parent: funcBody);
283 // Set up conditional branch for (b == T(-1)).
284 builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
285 builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
286 fallthroughBlock);
287
288 // return T(0);
289 // } // (p < T(0))
290 builder.setInsertionPointToEnd(fallthroughBlock);
291 builder.create<func::ReturnOp>(zeroValue);
292 Block *loopHeader = builder.createBlock(
293 parent: funcBody, insertPt: funcBody->end(), argTypes: {elementType, elementType, elementType},
294 locs: {builder.getLoc(), builder.getLoc(), builder.getLoc()});
295 // Set up conditional branch for (p < T(0)).
296 builder.setInsertionPointToEnd(pIsNeg->getBlock());
297 // Set initial values of 'result', 'b' and 'p' for the loop.
298 builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
299 ValueRange{oneValue, bArg, pArg});
300
301 // T result = T(1);
302 // while (true) {
303 // if (p & T(1))
304 // result *= b;
305 // p >>= T(1);
306 // if (p == T(0))
307 // return result;
308 // b *= b;
309 // }
310 Value resultTmp = loopHeader->getArgument(i: 0);
311 Value baseTmp = loopHeader->getArgument(i: 1);
312 Value powerTmp = loopHeader->getArgument(i: 2);
313 builder.setInsertionPointToEnd(loopHeader);
314
315 // if (p & T(1))
316 auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
317 arith::CmpIPredicate::ne,
318 builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
319 thenBlock = builder.createBlock(parent: funcBody);
320 // result *= b;
321 Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
322 fallthroughBlock = builder.createBlock(parent: funcBody, insertPt: funcBody->end(), argTypes: elementType,
323 locs: builder.getLoc());
324 builder.setInsertionPointToEnd(thenBlock);
325 builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
326 // Set up conditional branch for (p & T(1)).
327 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
328 builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
329 resultTmp);
330 // Merged 'result'.
331 newResultTmp = fallthroughBlock->getArgument(i: 0);
332
333 // p >>= T(1);
334 builder.setInsertionPointToEnd(fallthroughBlock);
335 Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
336
337 // if (p == T(0))
338 auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
339 newPowerTmp, zeroValue);
340 // return result;
341 thenBlock = builder.createBlock(parent: funcBody);
342 builder.create<func::ReturnOp>(newResultTmp);
343 fallthroughBlock = builder.createBlock(parent: funcBody);
344 // Set up conditional branch for (p == T(0)).
345 builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
346 builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
347
348 // b *= b;
349 // }
350 builder.setInsertionPointToEnd(fallthroughBlock);
351 Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
352 // Pass new values for 'result', 'b' and 'p' to the loop header.
353 builder.create<cf::BranchOp>(
354 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
355 return funcOp;
356}
357
358/// Convert IPowI into a call to a local function implementing
359/// the power operation. The local function computes a scalar result,
360/// so vector forms of IPowI are linearized.
361LogicalResult
362IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
363 PatternRewriter &rewriter) const {
364 auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
365
366 if (!baseType)
367 return rewriter.notifyMatchFailure(op, "non-integer base operand");
368
369 // The outlined software implementation must have been already
370 // generated.
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
372 if (!elementFunc)
373 return rewriter.notifyMatchFailure(op, "missing software implementation");
374
375 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
376 return success();
377}
378
379/// Create linkonce_odr function to implement the power function with
380/// the given \p funcType type inside \p module. The \p funcType must be
381/// 'FloatType (*)(FloatType, IntegerType)' function type.
382///
383/// template <typename T>
384/// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
385/// if (p == Tp{0})
386/// return Tb{1};
387/// bool isNegativePower{p < Tp{0}}
388/// bool isMin{p == std::numeric_limits<Tp>::min()};
389/// if (isMin) {
390/// p = std::numeric_limits<Tp>::max();
391/// } else if (isNegativePower) {
392/// p = -p;
393/// }
394/// Tb result = Tb{1};
395/// Tb origBase = Tb{b};
396/// while (true) {
397/// if (p & Tp{1})
398/// result *= b;
399/// p >>= Tp{1};
400/// if (p == Tp{0})
401/// break;
402/// b *= b;
403/// }
404/// if (isMin) {
405/// result *= origBase;
406/// }
407/// if (isNegativePower) {
408/// result = Tb{1} / result;
409/// }
410/// return result;
411/// }
412static func::FuncOp createElementFPowIFunc(ModuleOp *module,
413 FunctionType funcType) {
414 auto baseType = cast<FloatType>(funcType.getInput(0));
415 auto powType = cast<IntegerType>(funcType.getInput(1));
416 ImplicitLocOpBuilder builder =
417 ImplicitLocOpBuilder::atBlockEnd(loc: module->getLoc(), block: module->getBody());
418
419 std::string funcName("__mlir_math_fpowi");
420 llvm::raw_string_ostream nameOS(funcName);
421 nameOS << '_' << baseType;
422 nameOS << '_' << powType;
423 auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
425 Attribute linkage =
426 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
427 funcOp->setAttr("llvm.linkage", linkage);
428 funcOp.setPrivate();
429
430 Block *entryBlock = funcOp.addEntryBlock();
431 Region *funcBody = entryBlock->getParent();
432
433 Value bArg = funcOp.getArgument(0);
434 Value pArg = funcOp.getArgument(1);
435 builder.setInsertionPointToEnd(entryBlock);
436 Value oneBValue = builder.create<arith::ConstantOp>(
437 baseType, builder.getFloatAttr(baseType, 1.0));
438 Value zeroPValue = builder.create<arith::ConstantOp>(
439 powType, builder.getIntegerAttr(powType, 0));
440 Value onePValue = builder.create<arith::ConstantOp>(
441 powType, builder.getIntegerAttr(powType, 1));
442 Value minPValue = builder.create<arith::ConstantOp>(
443 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
444 powType.getWidth())));
445 Value maxPValue = builder.create<arith::ConstantOp>(
446 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
447 powType.getWidth())));
448
449 // if (p == Tp{0})
450 // return Tb{1};
451 auto pIsZero =
452 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
453 Block *thenBlock = builder.createBlock(parent: funcBody);
454 builder.create<func::ReturnOp>(oneBValue);
455 Block *fallthroughBlock = builder.createBlock(parent: funcBody);
456 // Set up conditional branch for (p == Tp{0}).
457 builder.setInsertionPointToEnd(pIsZero->getBlock());
458 builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
459
460 builder.setInsertionPointToEnd(fallthroughBlock);
461 // bool isNegativePower{p < Tp{0}}
462 auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
463 zeroPValue);
464 // bool isMin{p == std::numeric_limits<Tp>::min()};
465 auto pIsMin =
466 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
467
468 // if (isMin) {
469 // p = std::numeric_limits<Tp>::max();
470 // } else if (isNegativePower) {
471 // p = -p;
472 // }
473 Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
474 auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
475 pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);
476
477 // Tb result = Tb{1};
478 // Tb origBase = Tb{b};
479 // while (true) {
480 // if (p & Tp{1})
481 // result *= b;
482 // p >>= Tp{1};
483 // if (p == Tp{0})
484 // break;
485 // b *= b;
486 // }
487 Block *loopHeader = builder.createBlock(
488 funcBody, funcBody->end(), {baseType, baseType, powType},
489 {builder.getLoc(), builder.getLoc(), builder.getLoc()});
490 // Set initial values of 'result', 'b' and 'p' for the loop.
491 builder.setInsertionPointToEnd(pInit->getBlock());
492 builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});
493
494 // Create loop body.
495 Value resultTmp = loopHeader->getArgument(i: 0);
496 Value baseTmp = loopHeader->getArgument(i: 1);
497 Value powerTmp = loopHeader->getArgument(i: 2);
498 builder.setInsertionPointToEnd(loopHeader);
499
500 // if (p & Tp{1})
501 auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
502 arith::CmpIPredicate::ne,
503 builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
504 thenBlock = builder.createBlock(parent: funcBody);
505 // result *= b;
506 Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
507 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
508 builder.getLoc());
509 builder.setInsertionPointToEnd(thenBlock);
510 builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
511 // Set up conditional branch for (p & Tp{1}).
512 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
513 builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
514 resultTmp);
515 // Merged 'result'.
516 newResultTmp = fallthroughBlock->getArgument(i: 0);
517
518 // p >>= Tp{1};
519 builder.setInsertionPointToEnd(fallthroughBlock);
520 Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);
521
522 // if (p == Tp{0})
523 auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
524 newPowerTmp, zeroPValue);
525 // break;
526 //
527 // The conditional branch is finalized below with a jump to
528 // the loop exit block.
529 fallthroughBlock = builder.createBlock(parent: funcBody);
530
531 // b *= b;
532 // }
533 builder.setInsertionPointToEnd(fallthroughBlock);
534 Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
535 // Pass new values for 'result', 'b' and 'p' to the loop header.
536 builder.create<cf::BranchOp>(
537 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
538
539 // Set up conditional branch for early loop exit:
540 // if (p == Tp{0})
541 // break;
542 Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
543 builder.getLoc());
544 builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
545 builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
546 fallthroughBlock, ValueRange{});
547
548 // if (isMin) {
549 // result *= origBase;
550 // }
551 newResultTmp = loopExit->getArgument(i: 0);
552 thenBlock = builder.createBlock(parent: funcBody);
553 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
554 builder.getLoc());
555 builder.setInsertionPointToEnd(loopExit);
556 builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
557 newResultTmp);
558 builder.setInsertionPointToEnd(thenBlock);
559 newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
560 builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
561
562 /// if (isNegativePower) {
563 /// result = Tb{1} / result;
564 /// }
565 newResultTmp = fallthroughBlock->getArgument(i: 0);
566 thenBlock = builder.createBlock(parent: funcBody);
567 Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
568 builder.getLoc());
569 builder.setInsertionPointToEnd(fallthroughBlock);
570 builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
571 newResultTmp);
572 builder.setInsertionPointToEnd(thenBlock);
573 newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
574 builder.create<cf::BranchOp>(newResultTmp, returnBlock);
575
576 // return result;
577 builder.setInsertionPointToEnd(returnBlock);
578 builder.create<func::ReturnOp>(returnBlock->getArgument(0));
579
580 return funcOp;
581}
582
583/// Convert FPowI into a call to a local function implementing
584/// the power operation. The local function computes a scalar result,
585/// so vector forms of FPowI are linearized.
586LogicalResult
587FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
588 PatternRewriter &rewriter) const {
589 if (dyn_cast<VectorType>(op.getType()))
590 return rewriter.notifyMatchFailure(op, "non-scalar operation");
591
592 FunctionType funcType = getElementalFuncTypeForOp(op);
593
594 // The outlined software implementation must have been already
595 // generated.
596 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
597 if (!elementFunc)
598 return rewriter.notifyMatchFailure(op, "missing software implementation");
599
600 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
601 return success();
602}
603
604/// Create function to implement the ctlz function the given \p elementType type
605/// inside \p module. The \p elementType must be IntegerType, an the created
606/// function has 'IntegerType (*)(IntegerType)' function type.
607///
608/// template <typename T>
609/// T __mlir_math_ctlz_*(T x) {
610/// bits = sizeof(x) * 8;
611/// if (x == 0)
612/// return bits;
613///
614/// uint32_t n = 0;
615/// for (int i = 1; i < bits; ++i) {
616/// if (x < 0) continue;
617/// n++;
618/// x <<= 1;
619/// }
620/// return n;
621/// }
622///
623/// Converts to (for i32):
624///
625/// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
626/// %c_32 = arith.constant 32 : index
627/// %c_0 = arith.constant 0 : i32
628/// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
629/// %out = scf.if %arg_eq_zero {
630/// scf.yield %c_32 : i32
631/// } else {
632/// %c_1index = arith.constant 1 : index
633/// %c_1i32 = arith.constant 1 : i32
634/// %n = arith.constant 0 : i32
635/// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
636/// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
637/// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
638/// %yield_val = scf.if %cond {
639/// scf.yield %arg_iter, %n_iter : i32, i32
640/// } else {
641/// %arg_next = arith.shli %arg_iter, %c_1i32 : i32
642/// %n_next = arith.addi %n_iter, %c_1i32 : i32
643/// scf.yield %arg_next, %n_next : i32, i32
644/// }
645/// scf.yield %yield_val: i32, i32
646/// }
647/// scf.yield %n_out : i32
648/// }
649/// return %out: i32
650/// }
651static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
652 if (!isa<IntegerType>(Val: elementType)) {
653 LLVM_DEBUG({
654 DBGS() << "non-integer element type for CtlzFunc; type was: ";
655 elementType.print(llvm::dbgs());
656 });
657 llvm_unreachable("non-integer element type");
658 }
659 int64_t bitWidth = elementType.getIntOrFloatBitWidth();
660
661 Location loc = module->getLoc();
662 ImplicitLocOpBuilder builder =
663 ImplicitLocOpBuilder::atBlockEnd(loc, block: module->getBody());
664
665 std::string funcName("__mlir_math_ctlz");
666 llvm::raw_string_ostream nameOS(funcName);
667 nameOS << '_' << elementType;
668 FunctionType funcType =
669 FunctionType::get(builder.getContext(), {elementType}, elementType);
670 auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
671
672 // LinkonceODR ensures that there is only one implementation of this function
673 // across all math.ctlz functions that are lowered in this way.
674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
675 Attribute linkage =
676 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
677 funcOp->setAttr("llvm.linkage", linkage);
678 funcOp.setPrivate();
679
680 // set the insertion point to the start of the function
681 Block *funcBody = funcOp.addEntryBlock();
682 builder.setInsertionPointToStart(funcBody);
683
684 Value arg = funcOp.getArgument(0);
685 Type indexType = builder.getIndexType();
686 Value bitWidthValue = builder.create<arith::ConstantOp>(
687 elementType, builder.getIntegerAttr(elementType, bitWidth));
688 Value zeroValue = builder.create<arith::ConstantOp>(
689 elementType, builder.getIntegerAttr(elementType, 0));
690
691 Value inputEqZero =
692 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
693
694 // if input == 0, return bit width, else enter loop.
695 scf::IfOp ifOp = builder.create<scf::IfOp>(
696 elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true);
697 ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
698
699 auto elseBuilder =
700 ImplicitLocOpBuilder::atBlockEnd(loc, block: &ifOp.getElseRegion().front());
701
702 Value oneIndex = elseBuilder.create<arith::ConstantOp>(
703 indexType, elseBuilder.getIndexAttr(1));
704 Value oneValue = elseBuilder.create<arith::ConstantOp>(
705 elementType, elseBuilder.getIntegerAttr(elementType, 1));
706 Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
707 indexType, elseBuilder.getIndexAttr(bitWidth));
708 Value nValue = elseBuilder.create<arith::ConstantOp>(
709 elementType, elseBuilder.getIntegerAttr(elementType, 0));
710
711 auto loop = elseBuilder.create<scf::ForOp>(
712 oneIndex, bitWidthIndex, oneIndex,
713 // Initial values for two loop induction variables, the arg which is being
714 // shifted left in each iteration, and the n value which tracks the count
715 // of leading zeros.
716 ValueRange{arg, nValue},
717 // Callback to build the body of the for loop
718 // if (arg < 0) {
719 // continue;
720 // } else {
721 // n++;
722 // arg <<= 1;
723 // }
724 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
725 Value argIter = args[0];
726 Value nIter = args[1];
727
728 Value argIsNonNegative = b.create<arith::CmpIOp>(
729 loc, arith::CmpIPredicate::slt, argIter, zeroValue);
730 scf::IfOp ifOp = b.create<scf::IfOp>(
731 loc, argIsNonNegative,
732 [&](OpBuilder &b, Location loc) {
733 // If arg is negative, continue (effectively, break)
734 b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter});
735 },
736 [&](OpBuilder &b, Location loc) {
737 // Otherwise, increment n and shift arg left.
738 Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
739 Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
740 b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext});
741 });
742 b.create<scf::YieldOp>(loc, ifOp.getResults());
743 });
744 elseBuilder.create<scf::YieldOp>(loop.getResult(1));
745
746 builder.create<func::ReturnOp>(ifOp.getResult(0));
747 return funcOp;
748}
749
750/// Convert ctlz into a call to a local function implementing the ctlz
751/// operation.
752LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
753 PatternRewriter &rewriter) const {
754 if (dyn_cast<VectorType>(op.getType()))
755 return rewriter.notifyMatchFailure(op, "non-scalar operation");
756
757 Type type = getElementTypeOrSelf(op.getResult().getType());
758 func::FuncOp elementFunc = getFuncOpCallback(op, type);
759 if (!elementFunc)
760 return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
761 diag << "Missing software implementation for op " << op->getName()
762 << " and type " << type;
763 });
764
765 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand());
766 return success();
767}
768
769namespace {
770struct ConvertMathToFuncsPass
771 : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
772 ConvertMathToFuncsPass() = default;
773 ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
774 : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
775
776 void runOnOperation() override;
777
778private:
779 // Return true, if this FPowI operation must be converted
780 // because the width of its exponent's type is greater than
781 // or equal to minWidthOfFPowIExponent option value.
782 bool isFPowIConvertible(math::FPowIOp op);
783
784 // Generate outlined implementations for power operations
785 // and store them in funcImpls map.
786 void generateOpImplementations();
787
788 // A map between pairs of (operation, type) deduced from operations that this
789 // pass will convert, and the corresponding outlined software implementations
790 // of these operations for the given type.
791 DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
792};
793} // namespace
794
795bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
796 auto expTy =
797 dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
798 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
799}
800
801void ConvertMathToFuncsPass::generateOpImplementations() {
802 ModuleOp module = getOperation();
803
804 module.walk([&](Operation *op) {
805 TypeSwitch<Operation *>(op)
806 .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
807 if (!convertCtlz)
808 return;
809 Type resultType = getElementTypeOrSelf(op.getResult().getType());
810
811 // Generate the software implementation of this operation,
812 // if it has not been generated yet.
813 auto key = std::pair(op->getName(), resultType);
814 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
815 if (entry.second)
816 entry.first->second = createCtlzFunc(&module, resultType);
817 })
818 .Case<math::IPowIOp>([&](math::IPowIOp op) {
819 Type resultType = getElementTypeOrSelf(op.getResult().getType());
820
821 // Generate the software implementation of this operation,
822 // if it has not been generated yet.
823 auto key = std::pair(op->getName(), resultType);
824 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
825 if (entry.second)
826 entry.first->second = createElementIPowIFunc(&module, resultType);
827 })
828 .Case<math::FPowIOp>([&](math::FPowIOp op) {
829 if (!isFPowIConvertible(op))
830 return;
831
832 FunctionType funcType = getElementalFuncTypeForOp(op);
833
834 // Generate the software implementation of this operation,
835 // if it has not been generated yet.
836 // FPowI implementations are mapped via the FunctionType
837 // created from the operation's result and operands.
838 auto key = std::pair(op->getName(), funcType);
839 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
840 if (entry.second)
841 entry.first->second = createElementFPowIFunc(&module, funcType);
842 });
843 });
844}
845
846void ConvertMathToFuncsPass::runOnOperation() {
847 ModuleOp module = getOperation();
848
849 // Create outlined implementations for power operations.
850 generateOpImplementations();
851
852 RewritePatternSet patterns(&getContext());
853 patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
854 VecOpToScalarOp<math::CountLeadingZerosOp>>(
855 patterns.getContext());
856
857 // For the given Type Returns FuncOp stored in funcImpls map.
858 auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
859 auto it = funcImpls.find(std::pair(op->getName(), type));
860 if (it == funcImpls.end())
861 return {};
862
863 return it->second;
864 };
865 patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
866 getFuncOpByType);
867
868 if (convertCtlz)
869 patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
870
871 ConversionTarget target(getContext());
872 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
873 func::FuncDialect, scf::SCFDialect,
874 vector::VectorDialect>();
875
876 target.addIllegalOp<math::IPowIOp>();
877 if (convertCtlz)
878 target.addIllegalOp<math::CountLeadingZerosOp>();
879 target.addDynamicallyLegalOp<math::FPowIOp>(
880 [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
881 if (failed(applyPartialConversion(module, target, std::move(patterns))))
882 signalPassFailure();
883}
884

source code of mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp