1//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/Math/IR/Math.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/Utils/IndexingUtils.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
20#include "mlir/IR/ImplicitLocOpBuilder.h"
21#include "mlir/IR/TypeUtilities.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/TypeSwitch.h"
26#include "llvm/Support/Debug.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34
35#define DEBUG_TYPE "math-to-funcs"
36#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37
38namespace {
39// Pattern to convert vector operations to scalar operations.
40template <typename Op>
41struct VecOpToScalarOp : public OpRewritePattern<Op> {
42public:
43 using OpRewritePattern<Op>::OpRewritePattern;
44
45 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46};
47
48// Callback type for getting pre-generated FuncOp implementing
49// an operation of the given type.
50using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
51
52// Pattern to convert scalar IPowIOp into a call of outlined
53// software implementation.
54class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
55public:
56 IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
57 : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
58
59 /// Convert IPowI into a call to a local function implementing
60 /// the power operation. The local function computes a scalar result,
61 /// so vector forms of IPowI are linearized.
62 LogicalResult matchAndRewrite(math::IPowIOp op,
63 PatternRewriter &rewriter) const final;
64
65private:
66 GetFuncCallbackTy getFuncOpCallback;
67};
68
69// Pattern to convert scalar FPowIOp into a call of outlined
70// software implementation.
71class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
72public:
73 FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
74 : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
75
76 /// Convert FPowI into a call to a local function implementing
77 /// the power operation. The local function computes a scalar result,
78 /// so vector forms of FPowI are linearized.
79 LogicalResult matchAndRewrite(math::FPowIOp op,
80 PatternRewriter &rewriter) const final;
81
82private:
83 GetFuncCallbackTy getFuncOpCallback;
84};
85
86// Pattern to convert scalar ctlz into a call of outlined software
87// implementation.
88class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
89public:
90 CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
91 : OpRewritePattern<math::CountLeadingZerosOp>(context),
92 getFuncOpCallback(cb) {}
93
94 /// Convert ctlz into a call to a local function implementing
95 /// the count leading zeros operation.
96 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
97 PatternRewriter &rewriter) const final;
98
99private:
100 GetFuncCallbackTy getFuncOpCallback;
101};
102} // namespace
103
104template <typename Op>
105LogicalResult
106VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
107 Type opType = op.getType();
108 Location loc = op.getLoc();
109 auto vecType = dyn_cast<VectorType>(Val&: opType);
110
111 if (!vecType)
112 return rewriter.notifyMatchFailure(op, "not a vector operation");
113 if (!vecType.hasRank())
114 return rewriter.notifyMatchFailure(op, "unknown vector rank");
115 ArrayRef<int64_t> shape = vecType.getShape();
116 int64_t numElements = vecType.getNumElements();
117
118 Type resultElementType = vecType.getElementType();
119 Attribute initValueAttr;
120 if (isa<FloatType>(Val: resultElementType))
121 initValueAttr = FloatAttr::get(type: resultElementType, value: 0.0);
122 else
123 initValueAttr = IntegerAttr::get(type: resultElementType, value: 0);
124 Value result = rewriter.create<arith::ConstantOp>(
125 location: loc, args: DenseElementsAttr::get(type: vecType, values: initValueAttr));
126 SmallVector<int64_t> strides = computeStrides(sizes: shape);
127 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
128 SmallVector<int64_t> positions = delinearize(linearIndex, strides);
129 SmallVector<Value> operands;
130 for (Value input : op->getOperands())
131 operands.push_back(
132 Elt: rewriter.create<vector::ExtractOp>(location: loc, args&: input, args&: positions));
133 Value scalarOp =
134 rewriter.create<Op>(loc, vecType.getElementType(), operands);
135 result =
136 rewriter.create<vector::InsertOp>(location: loc, args&: scalarOp, args&: result, args&: positions);
137 }
138 rewriter.replaceOp(op, result);
139 return success();
140}
141
142static FunctionType getElementalFuncTypeForOp(Operation *op) {
143 SmallVector<Type, 1> resultTys(op->getNumResults());
144 SmallVector<Type, 2> inputTys(op->getNumOperands());
145 std::transform(first: op->result_type_begin(), last: op->result_type_end(),
146 result: resultTys.begin(),
147 unary_op: [](Type ty) { return getElementTypeOrSelf(type: ty); });
148 std::transform(first: op->operand_type_begin(), last: op->operand_type_end(),
149 result: inputTys.begin(),
150 unary_op: [](Type ty) { return getElementTypeOrSelf(type: ty); });
151 return FunctionType::get(context: op->getContext(), inputs: inputTys, results: resultTys);
152}
153
154/// Create linkonce_odr function to implement the power function with
155/// the given \p elementType type inside \p module. The \p elementType
156/// must be IntegerType, an the created function has
157/// 'IntegerType (*)(IntegerType, IntegerType)' function type.
158///
159/// template <typename T>
160/// T __mlir_math_ipowi_*(T b, T p) {
161/// if (p == T(0))
162/// return T(1);
163/// if (p < T(0)) {
164/// if (b == T(0))
165/// return T(1) / T(0); // trigger div-by-zero
166/// if (b == T(1))
167/// return T(1);
168/// if (b == T(-1)) {
169/// if (p & T(1))
170/// return T(-1);
171/// return T(1);
172/// }
173/// return T(0);
174/// }
175/// T result = T(1);
176/// while (true) {
177/// if (p & T(1))
178/// result *= b;
179/// p >>= T(1);
180/// if (p == T(0))
181/// return result;
182/// b *= b;
183/// }
184/// }
185static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
186 assert(isa<IntegerType>(elementType) &&
187 "non-integer element type for IPowIOp");
188
189 ImplicitLocOpBuilder builder =
190 ImplicitLocOpBuilder::atBlockEnd(loc: module->getLoc(), block: module->getBody());
191
192 std::string funcName("__mlir_math_ipowi");
193 llvm::raw_string_ostream nameOS(funcName);
194 nameOS << '_' << elementType;
195
196 FunctionType funcType = FunctionType::get(
197 context: builder.getContext(), inputs: {elementType, elementType}, results: elementType);
198 auto funcOp = builder.create<func::FuncOp>(args&: funcName, args&: funcType);
199 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
200 Attribute linkage =
201 LLVM::LinkageAttr::get(context: builder.getContext(), linkage: inlineLinkage);
202 funcOp->setAttr(name: "llvm.linkage", value: linkage);
203 funcOp.setPrivate();
204
205 Block *entryBlock = funcOp.addEntryBlock();
206 Region *funcBody = entryBlock->getParent();
207
208 Value bArg = funcOp.getArgument(idx: 0);
209 Value pArg = funcOp.getArgument(idx: 1);
210 builder.setInsertionPointToEnd(entryBlock);
211 Value zeroValue = builder.create<arith::ConstantOp>(
212 args&: elementType, args: builder.getIntegerAttr(type: elementType, value: 0));
213 Value oneValue = builder.create<arith::ConstantOp>(
214 args&: elementType, args: builder.getIntegerAttr(type: elementType, value: 1));
215 Value minusOneValue = builder.create<arith::ConstantOp>(
216 args&: elementType,
217 args: builder.getIntegerAttr(type: elementType,
218 value: APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
219 /*isSigned=*/true)));
220
221 // if (p == T(0))
222 // return T(1);
223 auto pIsZero =
224 builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: pArg, args&: zeroValue);
225 Block *thenBlock = builder.createBlock(parent: funcBody);
226 builder.create<func::ReturnOp>(args&: oneValue);
227 Block *fallthroughBlock = builder.createBlock(parent: funcBody);
228 // Set up conditional branch for (p == T(0)).
229 builder.setInsertionPointToEnd(pIsZero->getBlock());
230 builder.create<cf::CondBranchOp>(args&: pIsZero, args&: thenBlock, args&: fallthroughBlock);
231
232 // if (p < T(0)) {
233 builder.setInsertionPointToEnd(fallthroughBlock);
234 auto pIsNeg =
235 builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::sle, args&: pArg, args&: zeroValue);
236 // if (b == T(0))
237 builder.createBlock(parent: funcBody);
238 auto bIsZero =
239 builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: bArg, args&: zeroValue);
240 // return T(1) / T(0);
241 thenBlock = builder.createBlock(parent: funcBody);
242 builder.create<func::ReturnOp>(
243 args: builder.create<arith::DivSIOp>(args&: oneValue, args&: zeroValue).getResult());
244 fallthroughBlock = builder.createBlock(parent: funcBody);
245 // Set up conditional branch for (b == T(0)).
246 builder.setInsertionPointToEnd(bIsZero->getBlock());
247 builder.create<cf::CondBranchOp>(args&: bIsZero, args&: thenBlock, args&: fallthroughBlock);
248
249 // if (b == T(1))
250 builder.setInsertionPointToEnd(fallthroughBlock);
251 auto bIsOne =
252 builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: bArg, args&: oneValue);
253 // return T(1);
254 thenBlock = builder.createBlock(parent: funcBody);
255 builder.create<func::ReturnOp>(args&: oneValue);
256 fallthroughBlock = builder.createBlock(parent: funcBody);
257 // Set up conditional branch for (b == T(1)).
258 builder.setInsertionPointToEnd(bIsOne->getBlock());
259 builder.create<cf::CondBranchOp>(args&: bIsOne, args&: thenBlock, args&: fallthroughBlock);
260
261 // if (b == T(-1)) {
262 builder.setInsertionPointToEnd(fallthroughBlock);
263 auto bIsMinusOne = builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq,
264 args&: bArg, args&: minusOneValue);
265 // if (p & T(1))
266 builder.createBlock(parent: funcBody);
267 auto pIsOdd = builder.create<arith::CmpIOp>(
268 args: arith::CmpIPredicate::ne, args: builder.create<arith::AndIOp>(args&: pArg, args&: oneValue),
269 args&: zeroValue);
270 // return T(-1);
271 thenBlock = builder.createBlock(parent: funcBody);
272 builder.create<func::ReturnOp>(args&: minusOneValue);
273 fallthroughBlock = builder.createBlock(parent: funcBody);
274 // Set up conditional branch for (p & T(1)).
275 builder.setInsertionPointToEnd(pIsOdd->getBlock());
276 builder.create<cf::CondBranchOp>(args&: pIsOdd, args&: thenBlock, args&: fallthroughBlock);
277
278 // return T(1);
279 // } // b == T(-1)
280 builder.setInsertionPointToEnd(fallthroughBlock);
281 builder.create<func::ReturnOp>(args&: oneValue);
282 fallthroughBlock = builder.createBlock(parent: funcBody);
283 // Set up conditional branch for (b == T(-1)).
284 builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
285 builder.create<cf::CondBranchOp>(args&: bIsMinusOne, args: pIsOdd->getBlock(),
286 args&: fallthroughBlock);
287
288 // return T(0);
289 // } // (p < T(0))
290 builder.setInsertionPointToEnd(fallthroughBlock);
291 builder.create<func::ReturnOp>(args&: zeroValue);
292 Block *loopHeader = builder.createBlock(
293 parent: funcBody, insertPt: funcBody->end(), argTypes: {elementType, elementType, elementType},
294 locs: {builder.getLoc(), builder.getLoc(), builder.getLoc()});
295 // Set up conditional branch for (p < T(0)).
296 builder.setInsertionPointToEnd(pIsNeg->getBlock());
297 // Set initial values of 'result', 'b' and 'p' for the loop.
298 builder.create<cf::CondBranchOp>(args&: pIsNeg, args: bIsZero->getBlock(), args&: loopHeader,
299 args: ValueRange{oneValue, bArg, pArg});
300
301 // T result = T(1);
302 // while (true) {
303 // if (p & T(1))
304 // result *= b;
305 // p >>= T(1);
306 // if (p == T(0))
307 // return result;
308 // b *= b;
309 // }
310 Value resultTmp = loopHeader->getArgument(i: 0);
311 Value baseTmp = loopHeader->getArgument(i: 1);
312 Value powerTmp = loopHeader->getArgument(i: 2);
313 builder.setInsertionPointToEnd(loopHeader);
314
315 // if (p & T(1))
316 auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
317 args: arith::CmpIPredicate::ne,
318 args: builder.create<arith::AndIOp>(args&: powerTmp, args&: oneValue), args&: zeroValue);
319 thenBlock = builder.createBlock(parent: funcBody);
320 // result *= b;
321 Value newResultTmp = builder.create<arith::MulIOp>(args&: resultTmp, args&: baseTmp);
322 fallthroughBlock = builder.createBlock(parent: funcBody, insertPt: funcBody->end(), argTypes: elementType,
323 locs: builder.getLoc());
324 builder.setInsertionPointToEnd(thenBlock);
325 builder.create<cf::BranchOp>(args&: newResultTmp, args&: fallthroughBlock);
326 // Set up conditional branch for (p & T(1)).
327 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
328 builder.create<cf::CondBranchOp>(args&: powerTmpIsOdd, args&: thenBlock, args&: fallthroughBlock,
329 args&: resultTmp);
330 // Merged 'result'.
331 newResultTmp = fallthroughBlock->getArgument(i: 0);
332
333 // p >>= T(1);
334 builder.setInsertionPointToEnd(fallthroughBlock);
335 Value newPowerTmp = builder.create<arith::ShRUIOp>(args&: powerTmp, args&: oneValue);
336
337 // if (p == T(0))
338 auto newPowerIsZero = builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq,
339 args&: newPowerTmp, args&: zeroValue);
340 // return result;
341 thenBlock = builder.createBlock(parent: funcBody);
342 builder.create<func::ReturnOp>(args&: newResultTmp);
343 fallthroughBlock = builder.createBlock(parent: funcBody);
344 // Set up conditional branch for (p == T(0)).
345 builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
346 builder.create<cf::CondBranchOp>(args&: newPowerIsZero, args&: thenBlock, args&: fallthroughBlock);
347
348 // b *= b;
349 // }
350 builder.setInsertionPointToEnd(fallthroughBlock);
351 Value newBaseTmp = builder.create<arith::MulIOp>(args&: baseTmp, args&: baseTmp);
352 // Pass new values for 'result', 'b' and 'p' to the loop header.
353 builder.create<cf::BranchOp>(
354 args: ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, args&: loopHeader);
355 return funcOp;
356}
357
358/// Convert IPowI into a call to a local function implementing
359/// the power operation. The local function computes a scalar result,
360/// so vector forms of IPowI are linearized.
361LogicalResult
362IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
363 PatternRewriter &rewriter) const {
364 auto baseType = dyn_cast<IntegerType>(Val: op.getOperands()[0].getType());
365
366 if (!baseType)
367 return rewriter.notifyMatchFailure(arg&: op, msg: "non-integer base operand");
368
369 // The outlined software implementation must have been already
370 // generated.
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
372 if (!elementFunc)
373 return rewriter.notifyMatchFailure(arg&: op, msg: "missing software implementation");
374
375 rewriter.replaceOpWithNewOp<func::CallOp>(op, args&: elementFunc, args: op.getOperands());
376 return success();
377}
378
379/// Create linkonce_odr function to implement the power function with
380/// the given \p funcType type inside \p module. The \p funcType must be
381/// 'FloatType (*)(FloatType, IntegerType)' function type.
382///
383/// template <typename T>
384/// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
385/// if (p == Tp{0})
386/// return Tb{1};
387/// bool isNegativePower{p < Tp{0}}
388/// bool isMin{p == std::numeric_limits<Tp>::min()};
389/// if (isMin) {
390/// p = std::numeric_limits<Tp>::max();
391/// } else if (isNegativePower) {
392/// p = -p;
393/// }
394/// Tb result = Tb{1};
395/// Tb origBase = Tb{b};
396/// while (true) {
397/// if (p & Tp{1})
398/// result *= b;
399/// p >>= Tp{1};
400/// if (p == Tp{0})
401/// break;
402/// b *= b;
403/// }
404/// if (isMin) {
405/// result *= origBase;
406/// }
407/// if (isNegativePower) {
408/// result = Tb{1} / result;
409/// }
410/// return result;
411/// }
412static func::FuncOp createElementFPowIFunc(ModuleOp *module,
413 FunctionType funcType) {
414 auto baseType = cast<FloatType>(Val: funcType.getInput(i: 0));
415 auto powType = cast<IntegerType>(Val: funcType.getInput(i: 1));
416 ImplicitLocOpBuilder builder =
417 ImplicitLocOpBuilder::atBlockEnd(loc: module->getLoc(), block: module->getBody());
418
419 std::string funcName("__mlir_math_fpowi");
420 llvm::raw_string_ostream nameOS(funcName);
421 nameOS << '_' << baseType;
422 nameOS << '_' << powType;
423 auto funcOp = builder.create<func::FuncOp>(args&: funcName, args&: funcType);
424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
425 Attribute linkage =
426 LLVM::LinkageAttr::get(context: builder.getContext(), linkage: inlineLinkage);
427 funcOp->setAttr(name: "llvm.linkage", value: linkage);
428 funcOp.setPrivate();
429
430 Block *entryBlock = funcOp.addEntryBlock();
431 Region *funcBody = entryBlock->getParent();
432
433 Value bArg = funcOp.getArgument(idx: 0);
434 Value pArg = funcOp.getArgument(idx: 1);
435 builder.setInsertionPointToEnd(entryBlock);
436 Value oneBValue = builder.create<arith::ConstantOp>(
437 args&: baseType, args: builder.getFloatAttr(type: baseType, value: 1.0));
438 Value zeroPValue = builder.create<arith::ConstantOp>(
439 args&: powType, args: builder.getIntegerAttr(type: powType, value: 0));
440 Value onePValue = builder.create<arith::ConstantOp>(
441 args&: powType, args: builder.getIntegerAttr(type: powType, value: 1));
442 Value minPValue = builder.create<arith::ConstantOp>(
443 args&: powType, args: builder.getIntegerAttr(type: powType, value: llvm::APInt::getSignedMinValue(
444 numBits: powType.getWidth())));
445 Value maxPValue = builder.create<arith::ConstantOp>(
446 args&: powType, args: builder.getIntegerAttr(type: powType, value: llvm::APInt::getSignedMaxValue(
447 numBits: powType.getWidth())));
448
449 // if (p == Tp{0})
450 // return Tb{1};
451 auto pIsZero =
452 builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: pArg, args&: zeroPValue);
453 Block *thenBlock = builder.createBlock(parent: funcBody);
454 builder.create<func::ReturnOp>(args&: oneBValue);
455 Block *fallthroughBlock = builder.createBlock(parent: funcBody);
456 // Set up conditional branch for (p == Tp{0}).
457 builder.setInsertionPointToEnd(pIsZero->getBlock());
458 builder.create<cf::CondBranchOp>(args&: pIsZero, args&: thenBlock, args&: fallthroughBlock);
459
460 builder.setInsertionPointToEnd(fallthroughBlock);
461 // bool isNegativePower{p < Tp{0}}
462 auto pIsNeg = builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::sle, args&: pArg,
463 args&: zeroPValue);
464 // bool isMin{p == std::numeric_limits<Tp>::min()};
465 auto pIsMin =
466 builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: pArg, args&: minPValue);
467
468 // if (isMin) {
469 // p = std::numeric_limits<Tp>::max();
470 // } else if (isNegativePower) {
471 // p = -p;
472 // }
473 Value negP = builder.create<arith::SubIOp>(args&: zeroPValue, args&: pArg);
474 auto pInit = builder.create<arith::SelectOp>(args&: pIsNeg, args&: negP, args&: pArg);
475 pInit = builder.create<arith::SelectOp>(args&: pIsMin, args&: maxPValue, args&: pInit);
476
477 // Tb result = Tb{1};
478 // Tb origBase = Tb{b};
479 // while (true) {
480 // if (p & Tp{1})
481 // result *= b;
482 // p >>= Tp{1};
483 // if (p == Tp{0})
484 // break;
485 // b *= b;
486 // }
487 Block *loopHeader = builder.createBlock(
488 parent: funcBody, insertPt: funcBody->end(), argTypes: {baseType, baseType, powType},
489 locs: {builder.getLoc(), builder.getLoc(), builder.getLoc()});
490 // Set initial values of 'result', 'b' and 'p' for the loop.
491 builder.setInsertionPointToEnd(pInit->getBlock());
492 builder.create<cf::BranchOp>(args&: loopHeader, args: ValueRange{oneBValue, bArg, pInit});
493
494 // Create loop body.
495 Value resultTmp = loopHeader->getArgument(i: 0);
496 Value baseTmp = loopHeader->getArgument(i: 1);
497 Value powerTmp = loopHeader->getArgument(i: 2);
498 builder.setInsertionPointToEnd(loopHeader);
499
500 // if (p & Tp{1})
501 auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
502 args: arith::CmpIPredicate::ne,
503 args: builder.create<arith::AndIOp>(args&: powerTmp, args&: onePValue), args&: zeroPValue);
504 thenBlock = builder.createBlock(parent: funcBody);
505 // result *= b;
506 Value newResultTmp = builder.create<arith::MulFOp>(args&: resultTmp, args&: baseTmp);
507 fallthroughBlock = builder.createBlock(parent: funcBody, insertPt: funcBody->end(), argTypes: baseType,
508 locs: builder.getLoc());
509 builder.setInsertionPointToEnd(thenBlock);
510 builder.create<cf::BranchOp>(args&: newResultTmp, args&: fallthroughBlock);
511 // Set up conditional branch for (p & Tp{1}).
512 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
513 builder.create<cf::CondBranchOp>(args&: powerTmpIsOdd, args&: thenBlock, args&: fallthroughBlock,
514 args&: resultTmp);
515 // Merged 'result'.
516 newResultTmp = fallthroughBlock->getArgument(i: 0);
517
518 // p >>= Tp{1};
519 builder.setInsertionPointToEnd(fallthroughBlock);
520 Value newPowerTmp = builder.create<arith::ShRUIOp>(args&: powerTmp, args&: onePValue);
521
522 // if (p == Tp{0})
523 auto newPowerIsZero = builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq,
524 args&: newPowerTmp, args&: zeroPValue);
525 // break;
526 //
527 // The conditional branch is finalized below with a jump to
528 // the loop exit block.
529 fallthroughBlock = builder.createBlock(parent: funcBody);
530
531 // b *= b;
532 // }
533 builder.setInsertionPointToEnd(fallthroughBlock);
534 Value newBaseTmp = builder.create<arith::MulFOp>(args&: baseTmp, args&: baseTmp);
535 // Pass new values for 'result', 'b' and 'p' to the loop header.
536 builder.create<cf::BranchOp>(
537 args: ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, args&: loopHeader);
538
539 // Set up conditional branch for early loop exit:
540 // if (p == Tp{0})
541 // break;
542 Block *loopExit = builder.createBlock(parent: funcBody, insertPt: funcBody->end(), argTypes: baseType,
543 locs: builder.getLoc());
544 builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
545 builder.create<cf::CondBranchOp>(args&: newPowerIsZero, args&: loopExit, args&: newResultTmp,
546 args&: fallthroughBlock, args: ValueRange{});
547
548 // if (isMin) {
549 // result *= origBase;
550 // }
551 newResultTmp = loopExit->getArgument(i: 0);
552 thenBlock = builder.createBlock(parent: funcBody);
553 fallthroughBlock = builder.createBlock(parent: funcBody, insertPt: funcBody->end(), argTypes: baseType,
554 locs: builder.getLoc());
555 builder.setInsertionPointToEnd(loopExit);
556 builder.create<cf::CondBranchOp>(args&: pIsMin, args&: thenBlock, args&: fallthroughBlock,
557 args&: newResultTmp);
558 builder.setInsertionPointToEnd(thenBlock);
559 newResultTmp = builder.create<arith::MulFOp>(args&: newResultTmp, args&: bArg);
560 builder.create<cf::BranchOp>(args&: newResultTmp, args&: fallthroughBlock);
561
562 /// if (isNegativePower) {
563 /// result = Tb{1} / result;
564 /// }
565 newResultTmp = fallthroughBlock->getArgument(i: 0);
566 thenBlock = builder.createBlock(parent: funcBody);
567 Block *returnBlock = builder.createBlock(parent: funcBody, insertPt: funcBody->end(), argTypes: baseType,
568 locs: builder.getLoc());
569 builder.setInsertionPointToEnd(fallthroughBlock);
570 builder.create<cf::CondBranchOp>(args&: pIsNeg, args&: thenBlock, args&: returnBlock,
571 args&: newResultTmp);
572 builder.setInsertionPointToEnd(thenBlock);
573 newResultTmp = builder.create<arith::DivFOp>(args&: oneBValue, args&: newResultTmp);
574 builder.create<cf::BranchOp>(args&: newResultTmp, args&: returnBlock);
575
576 // return result;
577 builder.setInsertionPointToEnd(returnBlock);
578 builder.create<func::ReturnOp>(args: returnBlock->getArgument(i: 0));
579
580 return funcOp;
581}
582
583/// Convert FPowI into a call to a local function implementing
584/// the power operation. The local function computes a scalar result,
585/// so vector forms of FPowI are linearized.
586LogicalResult
587FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
588 PatternRewriter &rewriter) const {
589 if (isa<VectorType>(Val: op.getType()))
590 return rewriter.notifyMatchFailure(arg&: op, msg: "non-scalar operation");
591
592 FunctionType funcType = getElementalFuncTypeForOp(op);
593
594 // The outlined software implementation must have been already
595 // generated.
596 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
597 if (!elementFunc)
598 return rewriter.notifyMatchFailure(arg&: op, msg: "missing software implementation");
599
600 rewriter.replaceOpWithNewOp<func::CallOp>(op, args&: elementFunc, args: op.getOperands());
601 return success();
602}
603
604/// Create function to implement the ctlz function the given \p elementType type
605/// inside \p module. The \p elementType must be IntegerType, an the created
606/// function has 'IntegerType (*)(IntegerType)' function type.
607///
608/// template <typename T>
609/// T __mlir_math_ctlz_*(T x) {
610/// bits = sizeof(x) * 8;
611/// if (x == 0)
612/// return bits;
613///
614/// uint32_t n = 0;
615/// for (int i = 1; i < bits; ++i) {
616/// if (x < 0) continue;
617/// n++;
618/// x <<= 1;
619/// }
620/// return n;
621/// }
622///
623/// Converts to (for i32):
624///
625/// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
626/// %c_32 = arith.constant 32 : index
627/// %c_0 = arith.constant 0 : i32
628/// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
629/// %out = scf.if %arg_eq_zero {
630/// scf.yield %c_32 : i32
631/// } else {
632/// %c_1index = arith.constant 1 : index
633/// %c_1i32 = arith.constant 1 : i32
634/// %n = arith.constant 0 : i32
635/// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
636/// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
637/// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
638/// %yield_val = scf.if %cond {
639/// scf.yield %arg_iter, %n_iter : i32, i32
640/// } else {
641/// %arg_next = arith.shli %arg_iter, %c_1i32 : i32
642/// %n_next = arith.addi %n_iter, %c_1i32 : i32
643/// scf.yield %arg_next, %n_next : i32, i32
644/// }
645/// scf.yield %yield_val: i32, i32
646/// }
647/// scf.yield %n_out : i32
648/// }
649/// return %out: i32
650/// }
651static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
652 if (!isa<IntegerType>(Val: elementType)) {
653 LLVM_DEBUG({
654 DBGS() << "non-integer element type for CtlzFunc; type was: ";
655 elementType.print(llvm::dbgs());
656 });
657 llvm_unreachable("non-integer element type");
658 }
659 int64_t bitWidth = elementType.getIntOrFloatBitWidth();
660
661 Location loc = module->getLoc();
662 ImplicitLocOpBuilder builder =
663 ImplicitLocOpBuilder::atBlockEnd(loc, block: module->getBody());
664
665 std::string funcName("__mlir_math_ctlz");
666 llvm::raw_string_ostream nameOS(funcName);
667 nameOS << '_' << elementType;
668 FunctionType funcType =
669 FunctionType::get(context: builder.getContext(), inputs: {elementType}, results: elementType);
670 auto funcOp = builder.create<func::FuncOp>(args&: funcName, args&: funcType);
671
672 // LinkonceODR ensures that there is only one implementation of this function
673 // across all math.ctlz functions that are lowered in this way.
674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
675 Attribute linkage =
676 LLVM::LinkageAttr::get(context: builder.getContext(), linkage: inlineLinkage);
677 funcOp->setAttr(name: "llvm.linkage", value: linkage);
678 funcOp.setPrivate();
679
680 // set the insertion point to the start of the function
681 Block *funcBody = funcOp.addEntryBlock();
682 builder.setInsertionPointToStart(funcBody);
683
684 Value arg = funcOp.getArgument(idx: 0);
685 Type indexType = builder.getIndexType();
686 Value bitWidthValue = builder.create<arith::ConstantOp>(
687 args&: elementType, args: builder.getIntegerAttr(type: elementType, value: bitWidth));
688 Value zeroValue = builder.create<arith::ConstantOp>(
689 args&: elementType, args: builder.getIntegerAttr(type: elementType, value: 0));
690
691 Value inputEqZero =
692 builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: arg, args&: zeroValue);
693
694 // if input == 0, return bit width, else enter loop.
695 scf::IfOp ifOp = builder.create<scf::IfOp>(
696 args&: elementType, args&: inputEqZero, /*addThenBlock=*/args: true, /*addElseBlock=*/args: true);
697 ifOp.getThenBodyBuilder().create<scf::YieldOp>(location: loc, args&: bitWidthValue);
698
699 auto elseBuilder =
700 ImplicitLocOpBuilder::atBlockEnd(loc, block: &ifOp.getElseRegion().front());
701
702 Value oneIndex = elseBuilder.create<arith::ConstantOp>(
703 args&: indexType, args: elseBuilder.getIndexAttr(value: 1));
704 Value oneValue = elseBuilder.create<arith::ConstantOp>(
705 args&: elementType, args: elseBuilder.getIntegerAttr(type: elementType, value: 1));
706 Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
707 args&: indexType, args: elseBuilder.getIndexAttr(value: bitWidth));
708 Value nValue = elseBuilder.create<arith::ConstantOp>(
709 args&: elementType, args: elseBuilder.getIntegerAttr(type: elementType, value: 0));
710
711 auto loop = elseBuilder.create<scf::ForOp>(
712 args&: oneIndex, args&: bitWidthIndex, args&: oneIndex,
713 // Initial values for two loop induction variables, the arg which is being
714 // shifted left in each iteration, and the n value which tracks the count
715 // of leading zeros.
716 args: ValueRange{arg, nValue},
717 // Callback to build the body of the for loop
718 // if (arg < 0) {
719 // continue;
720 // } else {
721 // n++;
722 // arg <<= 1;
723 // }
724 args: [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
725 Value argIter = args[0];
726 Value nIter = args[1];
727
728 Value argIsNonNegative = b.create<arith::CmpIOp>(
729 location: loc, args: arith::CmpIPredicate::slt, args&: argIter, args&: zeroValue);
730 scf::IfOp ifOp = b.create<scf::IfOp>(
731 location: loc, args&: argIsNonNegative,
732 args: [&](OpBuilder &b, Location loc) {
733 // If arg is negative, continue (effectively, break)
734 b.create<scf::YieldOp>(location: loc, args: ValueRange{argIter, nIter});
735 },
736 args: [&](OpBuilder &b, Location loc) {
737 // Otherwise, increment n and shift arg left.
738 Value nNext = b.create<arith::AddIOp>(location: loc, args&: nIter, args&: oneValue);
739 Value argNext = b.create<arith::ShLIOp>(location: loc, args&: argIter, args&: oneValue);
740 b.create<scf::YieldOp>(location: loc, args: ValueRange{argNext, nNext});
741 });
742 b.create<scf::YieldOp>(location: loc, args: ifOp.getResults());
743 });
744 elseBuilder.create<scf::YieldOp>(args: loop.getResult(i: 1));
745
746 builder.create<func::ReturnOp>(args: ifOp.getResult(i: 0));
747 return funcOp;
748}
749
750/// Convert ctlz into a call to a local function implementing the ctlz
751/// operation.
752LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
753 PatternRewriter &rewriter) const {
754 if (isa<VectorType>(Val: op.getType()))
755 return rewriter.notifyMatchFailure(arg&: op, msg: "non-scalar operation");
756
757 Type type = getElementTypeOrSelf(type: op.getResult().getType());
758 func::FuncOp elementFunc = getFuncOpCallback(op, type);
759 if (!elementFunc)
760 return rewriter.notifyMatchFailure(op, reasonCallback: [&](::mlir::Diagnostic &diag) {
761 diag << "Missing software implementation for op " << op->getName()
762 << " and type " << type;
763 });
764
765 rewriter.replaceOpWithNewOp<func::CallOp>(op, args&: elementFunc, args: op.getOperand());
766 return success();
767}
768
769namespace {
770struct ConvertMathToFuncsPass
771 : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
772 ConvertMathToFuncsPass() = default;
773 ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
774 : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
775
776 void runOnOperation() override;
777
778private:
779 // Return true, if this FPowI operation must be converted
780 // because the width of its exponent's type is greater than
781 // or equal to minWidthOfFPowIExponent option value.
782 bool isFPowIConvertible(math::FPowIOp op);
783
784 // Reture true, if operation is integer type.
785 bool isConvertible(Operation *op);
786
787 // Generate outlined implementations for power operations
788 // and store them in funcImpls map.
789 void generateOpImplementations();
790
791 // A map between pairs of (operation, type) deduced from operations that this
792 // pass will convert, and the corresponding outlined software implementations
793 // of these operations for the given type.
794 DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
795};
796} // namespace
797
798bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
799 auto expTy =
800 dyn_cast<IntegerType>(Val: getElementTypeOrSelf(type: op.getRhs().getType()));
801 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
802}
803
804bool ConvertMathToFuncsPass::isConvertible(Operation *op) {
805 return isa<IntegerType>(Val: getElementTypeOrSelf(type: op->getResult(idx: 0).getType()));
806}
807
808void ConvertMathToFuncsPass::generateOpImplementations() {
809 ModuleOp module = getOperation();
810
811 module.walk(callback: [&](Operation *op) {
812 TypeSwitch<Operation *>(op)
813 .Case<math::CountLeadingZerosOp>(caseFn: [&](math::CountLeadingZerosOp op) {
814 if (!convertCtlz || !isConvertible(op))
815 return;
816 Type resultType = getElementTypeOrSelf(type: op.getResult().getType());
817
818 // Generate the software implementation of this operation,
819 // if it has not been generated yet.
820 auto key = std::pair(op->getName(), resultType);
821 auto entry = funcImpls.try_emplace(Key: key, Args: func::FuncOp{});
822 if (entry.second)
823 entry.first->second = createCtlzFunc(module: &module, elementType: resultType);
824 })
825 .Case<math::IPowIOp>(caseFn: [&](math::IPowIOp op) {
826 if (!isConvertible(op))
827 return;
828
829 Type resultType = getElementTypeOrSelf(type: op.getResult().getType());
830
831 // Generate the software implementation of this operation,
832 // if it has not been generated yet.
833 auto key = std::pair(op->getName(), resultType);
834 auto entry = funcImpls.try_emplace(Key: key, Args: func::FuncOp{});
835 if (entry.second)
836 entry.first->second = createElementIPowIFunc(module: &module, elementType: resultType);
837 })
838 .Case<math::FPowIOp>(caseFn: [&](math::FPowIOp op) {
839 if (!isFPowIConvertible(op))
840 return;
841
842 FunctionType funcType = getElementalFuncTypeForOp(op);
843
844 // Generate the software implementation of this operation,
845 // if it has not been generated yet.
846 // FPowI implementations are mapped via the FunctionType
847 // created from the operation's result and operands.
848 auto key = std::pair(op->getName(), funcType);
849 auto entry = funcImpls.try_emplace(Key: key, Args: func::FuncOp{});
850 if (entry.second)
851 entry.first->second = createElementFPowIFunc(module: &module, funcType);
852 });
853 });
854}
855
856void ConvertMathToFuncsPass::runOnOperation() {
857 ModuleOp module = getOperation();
858
859 // Create outlined implementations for power operations.
860 generateOpImplementations();
861
862 RewritePatternSet patterns(&getContext());
863 patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
864 VecOpToScalarOp<math::CountLeadingZerosOp>>(
865 arg: patterns.getContext());
866
867 // For the given Type Returns FuncOp stored in funcImpls map.
868 auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
869 auto it = funcImpls.find(Val: std::pair(op->getName(), type));
870 if (it == funcImpls.end())
871 return {};
872
873 return it->second;
874 };
875 patterns.add<IPowIOpLowering, FPowIOpLowering>(arg: patterns.getContext(),
876 args&: getFuncOpByType);
877
878 if (convertCtlz)
879 patterns.add<CtlzOpLowering>(arg: patterns.getContext(), args&: getFuncOpByType);
880
881 ConversionTarget target(getContext());
882 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
883 func::FuncDialect, scf::SCFDialect,
884 vector::VectorDialect>();
885
886 target.addDynamicallyLegalOp<math::IPowIOp>(
887 callback: [this](math::IPowIOp op) { return !isConvertible(op); });
888 if (convertCtlz) {
889 target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
890 callback: [this](math::CountLeadingZerosOp op) { return !isConvertible(op); });
891 }
892 target.addDynamicallyLegalOp<math::FPowIOp>(
893 callback: [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
894 if (failed(Result: applyPartialConversion(op: module, target, patterns: std::move(patterns))))
895 signalPassFailure();
896}
897

source code of mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp