diff options
author | Jasmine Tang <jjasmine@igalia.com> | 2025-08-05 15:22:37 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-08-05 15:22:37 -0700 |
commit | 9c6bb180407a7db004624d13d9de108d7cebc73c (patch) | |
tree | eb68266ee42e71cdba932c57cb34da4469581007 /llvm/lib/Analysis/ConstantFolding.cpp | |
parent | 34aed0ed5615583a8f1aaf9c036cc69fa88b3503 (diff) | |
download | llvm-9c6bb180407a7db004624d13d9de108d7cebc73c.zip llvm-9c6bb180407a7db004624d13d9de108d7cebc73c.tar.gz llvm-9c6bb180407a7db004624d13d9de108d7cebc73c.tar.bz2 |
[WebAssembly] Constant fold wasm.dot (#149619)
Constant fold wasm.dot of constant vectors/splats.
Test case added in
`llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll`
Related to https://github.com/llvm/llvm-project/issues/55933
Diffstat (limited to 'llvm/lib/Analysis/ConstantFolding.cpp')
-rw-r--r-- | llvm/lib/Analysis/ConstantFolding.cpp | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index dd98b62..4969528 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -1659,6 +1659,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) { case Intrinsic::aarch64_sve_convert_from_svbool: case Intrinsic::wasm_alltrue: case Intrinsic::wasm_anytrue: + case Intrinsic::wasm_dot: // WebAssembly float semantics are always known case Intrinsic::wasm_trunc_signed: case Intrinsic::wasm_trunc_unsigned: @@ -3989,6 +3990,30 @@ static Constant *ConstantFoldFixedVectorCall( } return ConstantVector::get(Result); } + case Intrinsic::wasm_dot: { + unsigned NumElements = + cast<FixedVectorType>(Operands[0]->getType())->getNumElements(); + + assert(NumElements == 8 && Result.size() == 4 && + "wasm dot takes i16x8 and produces i32x4"); + assert(Ty->isIntegerTy()); + int32_t MulVector[8]; + + for (unsigned I = 0; I < NumElements; ++I) { + ConstantInt *Elt0 = + cast<ConstantInt>(Operands[0]->getAggregateElement(I)); + ConstantInt *Elt1 = + cast<ConstantInt>(Operands[1]->getAggregateElement(I)); + + MulVector[I] = Elt0->getSExtValue() * Elt1->getSExtValue(); + } + for (unsigned I = 0; I < Result.size(); I++) { + int32_t IAdd = MulVector[I * 2] + MulVector[I * 2 + 1]; + Result[I] = ConstantInt::get(Ty, IAdd); + } + + return ConstantVector::get(Result); + } default: break; } |