aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/DirectX
diff options
context:
space:
mode:
authorFarzon Lotfi <1802579+farzonl@users.noreply.github.com>2024-03-25 18:01:46 -0400
committerGitHub <noreply@github.com>2024-03-25 18:01:46 -0400
commit060df78cdbbf70d5a6dfff3af1d435a5a811b886 (patch)
tree43becb5b4552de7d4235e10f2022a6e3466ce210 /llvm/lib/Target/DirectX
parent765d4c402fe2ff614a15a762bb7cefe7289663b4 (diff)
downloadllvm-060df78cdbbf70d5a6dfff3af1d435a5a811b886.zip
llvm-060df78cdbbf70d5a6dfff3af1d435a5a811b886.tar.gz
llvm-060df78cdbbf70d5a6dfff3af1d435a5a811b886.tar.bz2
[DXIL] Add Float `Dot` Intrinsic Lowering (#86071)
Completes #83626 - `CGBuiltin.cpp` - modify `getDotProductIntrinsic` to be able to emit `dot2`, `dot3`, and `dot4` intrinsics based on element count - `IntrinsicsDirectX.td` - for floating point add `dot2`, `dot3`, and `dot4` intrinsics - `DXIL.td` - add dxilop intrinsic lowering for `dot2`, `dot3`, & `dot4`. - `DXILOpLowering.cpp` - add vector arg flattening for dot product. - `DXILOpBuilder.h` - modify `createDXILOpCall` to take a `SmallVector` instead of an iterator range - `DXILOpBuilder.cpp` - modify `createDXILOpCall` by moving the small vector up to the calling function in `DXILOpLowering.cpp`. - Moving one function up gives us access to the `CallInst` and `Function` which were needed to distinguish the dot product intrinsics and get the operands without using the iterator.
Diffstat (limited to 'llvm/lib/Target/DirectX')
-rw-r--r--llvm/lib/Target/DirectX/DXIL.td9
-rw-r--r--llvm/lib/Target/DirectX/DXILOpBuilder.cpp8
-rw-r--r--llvm/lib/Target/DirectX/DXILOpBuilder.h5
-rw-r--r--llvm/lib/Target/DirectX/DXILOpLowering.cpp55
4 files changed, 67 insertions, 10 deletions
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index f7e69eb..2e6d58e 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -303,6 +303,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
"Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
"Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
+let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in
+ def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
+ "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">;
+let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in
+ def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
+ "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">;
+let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in
+ def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
+ "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">;
def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
"Reads the thread ID">;
def GroupId : DXILOpMapping<94, groupId, int_dx_group_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 0841ae9..0b3982e 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -254,7 +254,7 @@ namespace dxil {
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
Type *OverloadTy,
- llvm::iterator_range<Use *> Args) {
+ SmallVector<Value *> Args) {
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
OverloadKind Kind = getOverloadKind(OverloadTy);
@@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
}
- SmallVector<Value *> FullArgs;
- FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
- FullArgs.append(Args.begin(), Args.end());
- return B.CreateCall(DXILFn, FullArgs);
+
+ return B.CreateCall(DXILFn, Args);
}
Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index f3abcc6..5babeae 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -13,7 +13,7 @@
#define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
#include "DXILConstants.h"
-#include "llvm/ADT/iterator_range.h"
+#include "llvm/ADT/SmallVector.h"
namespace llvm {
class Module;
@@ -35,8 +35,7 @@ public:
/// \param OverloadTy Overload type of the DXIL Op call constructed
/// \return DXIL Op call constructed
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
- Type *OverloadTy,
- llvm::iterator_range<Use *> Args);
+ Type *OverloadTy, SmallVector<Value *> Args);
Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
static const char *getOpCodeName(dxil::OpCode DXILOp);
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 3e334b0..f09e322 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -30,6 +30,48 @@
using namespace llvm;
using namespace llvm::dxil;
+static bool isVectorArgExpansion(Function &F) {
+ switch (F.getIntrinsicID()) {
+ case Intrinsic::dx_dot2:
+ case Intrinsic::dx_dot3:
+ case Intrinsic::dx_dot4:
+ return true;
+ }
+ return false;
+}
+
+static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
+ SmallVector<Value *, 4> ExtractedElements;
+ auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+ for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
+ Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
+ Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
+ ExtractedElements.push_back(ExtractedElement);
+ }
+ return ExtractedElements;
+}
+
+static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
+ IRBuilder<> &Builder) {
+ // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
+ unsigned NumOperands = Orig->getNumOperands() - 1;
+ assert(NumOperands > 0);
+ Value *Arg0 = Orig->getOperand(0);
+ [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
+ assert(VecArg0);
+ SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
+ for (unsigned I = 1; I < NumOperands; ++I) {
+ Value *Arg = Orig->getOperand(I);
+ [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+ assert(VecArg);
+ assert(VecArg0->getElementType() == VecArg->getElementType());
+ assert(VecArg0->getNumElements() == VecArg->getNumElements());
+ auto NextOperandList = populateOperands(Arg, Builder);
+ NewOperands.append(NextOperandList.begin(), NextOperandList.end());
+ }
+ return NewOperands;
+}
+
static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
IRBuilder<> B(M.getContext());
DXILOpBuilder DXILB(M, B);
@@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
if (!CI)
continue;
+ SmallVector<Value *> Args;
+ Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
+ Args.emplace_back(DXILOpArg);
B.SetInsertPoint(CI);
- CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
- OverloadTy, CI->args());
+ if (isVectorArgExpansion(F)) {
+ SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
+ Args.append(NewArgs.begin(), NewArgs.end());
+ } else
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ CallInst *DXILCI =
+ DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
CI->replaceAllUsesWith(DXILCI);
CI->eraseFromParent();