1//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
/// \file This file contains DXIL intrinsic expansions for those that don't have
/// opcodes in DirectX Intermediate Language (DXIL).
11//===----------------------------------------------------------------------===//
12
13#include "DXILIntrinsicExpansion.h"
14#include "DirectX.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/SmallVector.h"
17#include "llvm/CodeGen/Passes.h"
18#include "llvm/IR/IRBuilder.h"
19#include "llvm/IR/Instruction.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/Intrinsics.h"
22#include "llvm/IR/IntrinsicsDirectX.h"
23#include "llvm/IR/Module.h"
24#include "llvm/IR/PassManager.h"
25#include "llvm/IR/Type.h"
26#include "llvm/Pass.h"
27#include "llvm/Support/ErrorHandling.h"
28#include "llvm/Support/MathExtras.h"
29
30#define DEBUG_TYPE "dxil-intrinsic-expansion"
31
32using namespace llvm;
33
34static bool isIntrinsicExpansion(Function &F) {
35 switch (F.getIntrinsicID()) {
36 case Intrinsic::abs:
37 case Intrinsic::exp:
38 case Intrinsic::log:
39 case Intrinsic::log10:
40 case Intrinsic::pow:
41 case Intrinsic::dx_any:
42 case Intrinsic::dx_clamp:
43 case Intrinsic::dx_uclamp:
44 case Intrinsic::dx_lerp:
45 case Intrinsic::dx_sdot:
46 case Intrinsic::dx_udot:
47 return true;
48 }
49 return false;
50}
51
52static bool expandAbs(CallInst *Orig) {
53 Value *X = Orig->getOperand(i_nocapture: 0);
54 IRBuilder<> Builder(Orig->getParent());
55 Builder.SetInsertPoint(Orig);
56 Type *Ty = X->getType();
57 Type *EltTy = Ty->getScalarType();
58 Constant *Zero = Ty->isVectorTy()
59 ? ConstantVector::getSplat(
60 EC: ElementCount::getFixed(
61 MinVal: cast<FixedVectorType>(Val: Ty)->getNumElements()),
62 Elt: ConstantInt::get(Ty: EltTy, V: 0))
63 : ConstantInt::get(Ty: EltTy, V: 0);
64 auto *V = Builder.CreateSub(LHS: Zero, RHS: X);
65 auto *MaxCall =
66 Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
67 Orig->replaceAllUsesWith(V: MaxCall);
68 Orig->eraseFromParent();
69 return true;
70}
71
72static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
73 assert(DotIntrinsic == Intrinsic::dx_sdot ||
74 DotIntrinsic == Intrinsic::dx_udot);
75 Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
76 ? Intrinsic::dx_imad
77 : Intrinsic::dx_umad;
78 Value *A = Orig->getOperand(i_nocapture: 0);
79 Value *B = Orig->getOperand(i_nocapture: 1);
80 Type *ATy = A->getType();
81 Type *BTy = B->getType();
82 assert(ATy->isVectorTy() && BTy->isVectorTy());
83
84 IRBuilder<> Builder(Orig->getParent());
85 Builder.SetInsertPoint(Orig);
86
87 auto *AVec = dyn_cast<FixedVectorType>(Val: A->getType());
88 Value *Elt0 = Builder.CreateExtractElement(Vec: A, Idx: (uint64_t)0);
89 Value *Elt1 = Builder.CreateExtractElement(Vec: B, Idx: (uint64_t)0);
90 Value *Result = Builder.CreateMul(LHS: Elt0, RHS: Elt1);
91 for (unsigned I = 1; I < AVec->getNumElements(); I++) {
92 Elt0 = Builder.CreateExtractElement(Vec: A, Idx: I);
93 Elt1 = Builder.CreateExtractElement(Vec: B, Idx: I);
94 Result = Builder.CreateIntrinsic(RetTy: Result->getType(), ID: MadIntrinsic,
95 Args: ArrayRef<Value *>{Elt0, Elt1, Result},
96 FMFSource: nullptr, Name: "dx.mad");
97 }
98 Orig->replaceAllUsesWith(V: Result);
99 Orig->eraseFromParent();
100 return true;
101}
102
103static bool expandExpIntrinsic(CallInst *Orig) {
104 Value *X = Orig->getOperand(i_nocapture: 0);
105 IRBuilder<> Builder(Orig->getParent());
106 Builder.SetInsertPoint(Orig);
107 Type *Ty = X->getType();
108 Type *EltTy = Ty->getScalarType();
109 Constant *Log2eConst =
110 Ty->isVectorTy() ? ConstantVector::getSplat(
111 EC: ElementCount::getFixed(
112 MinVal: cast<FixedVectorType>(Val: Ty)->getNumElements()),
113 Elt: ConstantFP::get(Ty: EltTy, V: numbers::log2ef))
114 : ConstantFP::get(Ty: EltTy, V: numbers::log2ef);
115 Value *NewX = Builder.CreateFMul(L: Log2eConst, R: X);
116 auto *Exp2Call =
117 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
118 Exp2Call->setTailCall(Orig->isTailCall());
119 Exp2Call->setAttributes(Orig->getAttributes());
120 Orig->replaceAllUsesWith(V: Exp2Call);
121 Orig->eraseFromParent();
122 return true;
123}
124
125static bool expandAnyIntrinsic(CallInst *Orig) {
126 Value *X = Orig->getOperand(i_nocapture: 0);
127 IRBuilder<> Builder(Orig->getParent());
128 Builder.SetInsertPoint(Orig);
129 Type *Ty = X->getType();
130 Type *EltTy = Ty->getScalarType();
131
132 if (!Ty->isVectorTy()) {
133 Value *Cond = EltTy->isFloatingPointTy()
134 ? Builder.CreateFCmpUNE(LHS: X, RHS: ConstantFP::get(Ty: EltTy, V: 0))
135 : Builder.CreateICmpNE(LHS: X, RHS: ConstantInt::get(Ty: EltTy, V: 0));
136 Orig->replaceAllUsesWith(V: Cond);
137 } else {
138 auto *XVec = dyn_cast<FixedVectorType>(Val: Ty);
139 Value *Cond =
140 EltTy->isFloatingPointTy()
141 ? Builder.CreateFCmpUNE(
142 LHS: X, RHS: ConstantVector::getSplat(
143 EC: ElementCount::getFixed(MinVal: XVec->getNumElements()),
144 Elt: ConstantFP::get(Ty: EltTy, V: 0)))
145 : Builder.CreateICmpNE(
146 LHS: X, RHS: ConstantVector::getSplat(
147 EC: ElementCount::getFixed(MinVal: XVec->getNumElements()),
148 Elt: ConstantInt::get(Ty: EltTy, V: 0)));
149 Value *Result = Builder.CreateExtractElement(Vec: Cond, Idx: (uint64_t)0);
150 for (unsigned I = 1; I < XVec->getNumElements(); I++) {
151 Value *Elt = Builder.CreateExtractElement(Vec: Cond, Idx: I);
152 Result = Builder.CreateOr(LHS: Result, RHS: Elt);
153 }
154 Orig->replaceAllUsesWith(V: Result);
155 }
156 Orig->eraseFromParent();
157 return true;
158}
159
160static bool expandLerpIntrinsic(CallInst *Orig) {
161 Value *X = Orig->getOperand(i_nocapture: 0);
162 Value *Y = Orig->getOperand(i_nocapture: 1);
163 Value *S = Orig->getOperand(i_nocapture: 2);
164 IRBuilder<> Builder(Orig->getParent());
165 Builder.SetInsertPoint(Orig);
166 auto *V = Builder.CreateFSub(L: Y, R: X);
167 V = Builder.CreateFMul(L: S, R: V);
168 auto *Result = Builder.CreateFAdd(L: X, R: V, Name: "dx.lerp");
169 Orig->replaceAllUsesWith(V: Result);
170 Orig->eraseFromParent();
171 return true;
172}
173
174static bool expandLogIntrinsic(CallInst *Orig,
175 float LogConstVal = numbers::ln2f) {
176 Value *X = Orig->getOperand(i_nocapture: 0);
177 IRBuilder<> Builder(Orig->getParent());
178 Builder.SetInsertPoint(Orig);
179 Type *Ty = X->getType();
180 Type *EltTy = Ty->getScalarType();
181 Constant *Ln2Const =
182 Ty->isVectorTy() ? ConstantVector::getSplat(
183 EC: ElementCount::getFixed(
184 MinVal: cast<FixedVectorType>(Val: Ty)->getNumElements()),
185 Elt: ConstantFP::get(Ty: EltTy, V: LogConstVal))
186 : ConstantFP::get(Ty: EltTy, V: LogConstVal);
187 auto *Log2Call =
188 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
189 Log2Call->setTailCall(Orig->isTailCall());
190 Log2Call->setAttributes(Orig->getAttributes());
191 auto *Result = Builder.CreateFMul(L: Ln2Const, R: Log2Call);
192 Orig->replaceAllUsesWith(V: Result);
193 Orig->eraseFromParent();
194 return true;
195}
196static bool expandLog10Intrinsic(CallInst *Orig) {
197 return expandLogIntrinsic(Orig, LogConstVal: numbers::ln2f / numbers::ln10f);
198}
199
200static bool expandPowIntrinsic(CallInst *Orig) {
201
202 Value *X = Orig->getOperand(i_nocapture: 0);
203 Value *Y = Orig->getOperand(i_nocapture: 1);
204 Type *Ty = X->getType();
205 IRBuilder<> Builder(Orig->getParent());
206 Builder.SetInsertPoint(Orig);
207
208 auto *Log2Call =
209 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
210 auto *Mul = Builder.CreateFMul(L: Log2Call, R: Y);
211 auto *Exp2Call =
212 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
213 Exp2Call->setTailCall(Orig->isTailCall());
214 Exp2Call->setAttributes(Orig->getAttributes());
215 Orig->replaceAllUsesWith(V: Exp2Call);
216 Orig->eraseFromParent();
217 return true;
218}
219
220static Intrinsic::ID getMaxForClamp(Type *ElemTy,
221 Intrinsic::ID ClampIntrinsic) {
222 if (ClampIntrinsic == Intrinsic::dx_uclamp)
223 return Intrinsic::umax;
224 assert(ClampIntrinsic == Intrinsic::dx_clamp);
225 if (ElemTy->isVectorTy())
226 ElemTy = ElemTy->getScalarType();
227 if (ElemTy->isIntegerTy())
228 return Intrinsic::smax;
229 assert(ElemTy->isFloatingPointTy());
230 return Intrinsic::maxnum;
231}
232
233static Intrinsic::ID getMinForClamp(Type *ElemTy,
234 Intrinsic::ID ClampIntrinsic) {
235 if (ClampIntrinsic == Intrinsic::dx_uclamp)
236 return Intrinsic::umin;
237 assert(ClampIntrinsic == Intrinsic::dx_clamp);
238 if (ElemTy->isVectorTy())
239 ElemTy = ElemTy->getScalarType();
240 if (ElemTy->isIntegerTy())
241 return Intrinsic::smin;
242 assert(ElemTy->isFloatingPointTy());
243 return Intrinsic::minnum;
244}
245
246static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
247 Value *X = Orig->getOperand(i_nocapture: 0);
248 Value *Min = Orig->getOperand(i_nocapture: 1);
249 Value *Max = Orig->getOperand(i_nocapture: 2);
250 Type *Ty = X->getType();
251 IRBuilder<> Builder(Orig->getParent());
252 Builder.SetInsertPoint(Orig);
253 auto *MaxCall = Builder.CreateIntrinsic(
254 RetTy: Ty, ID: getMaxForClamp(ElemTy: Ty, ClampIntrinsic), Args: {X, Min}, FMFSource: nullptr, Name: "dx.max");
255 auto *MinCall =
256 Builder.CreateIntrinsic(RetTy: Ty, ID: getMinForClamp(ElemTy: Ty, ClampIntrinsic),
257 Args: {MaxCall, Max}, FMFSource: nullptr, Name: "dx.min");
258
259 Orig->replaceAllUsesWith(V: MinCall);
260 Orig->eraseFromParent();
261 return true;
262}
263
264static bool expandIntrinsic(Function &F, CallInst *Orig) {
265 switch (F.getIntrinsicID()) {
266 case Intrinsic::abs:
267 return expandAbs(Orig);
268 case Intrinsic::exp:
269 return expandExpIntrinsic(Orig);
270 case Intrinsic::log:
271 return expandLogIntrinsic(Orig);
272 case Intrinsic::log10:
273 return expandLog10Intrinsic(Orig);
274 case Intrinsic::pow:
275 return expandPowIntrinsic(Orig);
276 case Intrinsic::dx_any:
277 return expandAnyIntrinsic(Orig);
278 case Intrinsic::dx_uclamp:
279 case Intrinsic::dx_clamp:
280 return expandClampIntrinsic(Orig, ClampIntrinsic: F.getIntrinsicID());
281 case Intrinsic::dx_lerp:
282 return expandLerpIntrinsic(Orig);
283 case Intrinsic::dx_sdot:
284 case Intrinsic::dx_udot:
285 return expandIntegerDot(Orig, DotIntrinsic: F.getIntrinsicID());
286 }
287 return false;
288}
289
290static bool expansionIntrinsics(Module &M) {
291 for (auto &F : make_early_inc_range(Range: M.functions())) {
292 if (!isIntrinsicExpansion(F))
293 continue;
294 bool IntrinsicExpanded = false;
295 for (User *U : make_early_inc_range(Range: F.users())) {
296 auto *IntrinsicCall = dyn_cast<CallInst>(Val: U);
297 if (!IntrinsicCall)
298 continue;
299 IntrinsicExpanded = expandIntrinsic(F, Orig: IntrinsicCall);
300 }
301 if (F.user_empty() && IntrinsicExpanded)
302 F.eraseFromParent();
303 }
304 return true;
305}
306
307PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
308 ModuleAnalysisManager &) {
309 if (expansionIntrinsics(M))
310 return PreservedAnalyses::none();
311 return PreservedAnalyses::all();
312}
313
314bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
315 return expansionIntrinsics(M);
316}
317
318char DXILIntrinsicExpansionLegacy::ID = 0;
319
320INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
321 "DXIL Intrinsic Expansion", false, false)
322INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
323 "DXIL Intrinsic Expansion", false, false)
324
325ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
326 return new DXILIntrinsicExpansionLegacy();
327}
328

// Source: llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp