diff options
author | Farzon Lotfi <1802579+farzonl@users.noreply.github.com> | 2024-03-25 18:01:46 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-25 18:01:46 -0400 |
commit | 060df78cdbbf70d5a6dfff3af1d435a5a811b886 (patch) | |
tree | 43becb5b4552de7d4235e10f2022a6e3466ce210 /llvm/lib/Target/DirectX | |
parent | 765d4c402fe2ff614a15a762bb7cefe7289663b4 (diff) | |
download | llvm-060df78cdbbf70d5a6dfff3af1d435a5a811b886.zip llvm-060df78cdbbf70d5a6dfff3af1d435a5a811b886.tar.gz llvm-060df78cdbbf70d5a6dfff3af1d435a5a811b886.tar.bz2 |
[DXIL] Add Float `Dot` Intrinsic Lowering (#86071)
Completes #83626
- `CGBuiltin.cpp` - modify `getDotProductIntrinsic` to be able to emit
`dot2`, `dot3`, and `dot4` intrinsics based on element count
- `IntrinsicsDirectX.td` - for floating point add `dot2`, `dot3`, and
`dot4` inntrinsics -`DXIL.td` add dxilop intrinsic lowering for `dot2`,
`dot3`, & `dot4`.
- `DXILOpLowering.cpp` - add vector arg flattening for dot product.
- `DXILOpBuilder.h` - modify `createDXILOpCall` to take a smallVector
instead of an iterator
- `DXILOpBuilder.cpp` - modify `createDXILOpCall` by moving the small
vector up to the calling function in `DXILOpLowering.cpp`.
- Moving one function up gives us access to the `CallInst` and
`Function` which were needed to distinguish the dot product intrinsics
and get the operands without using the iterator.
Diffstat (limited to 'llvm/lib/Target/DirectX')
-rw-r--r-- | llvm/lib/Target/DirectX/DXIL.td | 9 | ||||
-rw-r--r-- | llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 8 | ||||
-rw-r--r-- | llvm/lib/Target/DirectX/DXILOpBuilder.h | 5 | ||||
-rw-r--r-- | llvm/lib/Target/DirectX/DXILOpLowering.cpp | 55 |
4 files changed, 67 insertions, 10 deletions
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index f7e69eb..2e6d58e 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -303,6 +303,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad, "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">; def UMad : DXILOpMapping<49, tertiary, int_dx_umad, "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">; +let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in + def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">; +let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in + def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">; +let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in + def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">; def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id, "Reads the thread ID">; def GroupId : DXILOpMapping<94, groupId, int_dx_group_id, diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 0841ae9..0b3982e 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -254,7 +254,7 @@ namespace dxil { CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, Type *OverloadTy, - llvm::iterator_range<Use *> Args) { + SmallVector<Value *> Args) { const OpCodeProperty *Prop = getOpCodeProperty(OpCode); OverloadKind Kind = getOverloadKind(OverloadTy); @@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy); DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT); } - SmallVector<Value *> FullArgs; - FullArgs.emplace_back(B.getInt32((int32_t)OpCode)); - FullArgs.append(Args.begin(), Args.end()); - return B.CreateCall(DXILFn, FullArgs); + + return B.CreateCall(DXILFn, Args); } Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h index f3abcc6..5babeae 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.h +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h @@ -13,7 +13,7 @@ #define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H #include "DXILConstants.h" -#include "llvm/ADT/iterator_range.h" +#include "llvm/ADT/SmallVector.h" namespace llvm { class Module; @@ -35,8 +35,7 @@ public: /// \param OverloadTy Overload type of the DXIL Op call constructed /// \return DXIL Op call constructed CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, - Type *OverloadTy, - llvm::iterator_range<Use *> Args); + Type *OverloadTy, SmallVector<Value *> Args); Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT); static const char *getOpCodeName(dxil::OpCode DXILOp); diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 3e334b0..f09e322 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -30,6 +30,48 @@ using namespace llvm; using namespace llvm::dxil; +static bool isVectorArgExpansion(Function &F) { + switch (F.getIntrinsicID()) { + case Intrinsic::dx_dot2: + case Intrinsic::dx_dot3: + case Intrinsic::dx_dot4: + return true; + } + return false; +} + +static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) { + SmallVector<Value *, 4> ExtractedElements; + auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); + for (unsigned I = 0; I < VecArg->getNumElements(); ++I) { + Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I); + Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index); + ExtractedElements.push_back(ExtractedElement); + } + return ExtractedElements; +} + +static SmallVector<Value *> argVectorFlatten(CallInst *Orig, + IRBuilder<> &Builder) { + // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. + unsigned NumOperands = Orig->getNumOperands() - 1; + assert(NumOperands > 0); + Value *Arg0 = Orig->getOperand(0); + [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType()); + assert(VecArg0); + SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder); + for (unsigned I = 1; I < NumOperands; ++I) { + Value *Arg = Orig->getOperand(I); + [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); + assert(VecArg); + assert(VecArg0->getElementType() == VecArg->getElementType()); + assert(VecArg0->getNumElements() == VecArg->getNumElements()); + auto NextOperandList = populateOperands(Arg, Builder); + NewOperands.append(NextOperandList.begin(), NextOperandList.end()); + } + return NewOperands; +} + static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { IRBuilder<> B(M.getContext()); DXILOpBuilder DXILB(M, B); @@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { if (!CI) continue; + SmallVector<Value *> Args; + Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); + Args.emplace_back(DXILOpArg); B.SetInsertPoint(CI); - CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(), - OverloadTy, CI->args()); + if (isVectorArgExpansion(F)) { + SmallVector<Value *> NewArgs = argVectorFlatten(CI, B); + Args.append(NewArgs.begin(), NewArgs.end()); + } else + Args.append(CI->arg_begin(), CI->arg_end()); + + CallInst *DXILCI = + DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args); CI->replaceAllUsesWith(DXILCI); CI->eraseFromParent(); |