aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
authorSam Parker <sam.parker@arm.com>2024-07-17 09:21:52 +0100
committerGitHub <noreply@github.com>2024-07-17 09:21:52 +0100
commitd28ed29d6bd9f0389092775406fff7e6205d4d5f (patch)
treebe650c3abc641b4da02f1fcdb94c4a125881a691 /llvm/lib/Transforms/Utils/LoopUtils.cpp
parentcaaf8099efa87a7ebca8920971b7d7f719808591 (diff)
downloadllvm-d28ed29d6bd9f0389092775406fff7e6205d4d5f.zip
llvm-d28ed29d6bd9f0389092775406fff7e6205d4d5f.tar.gz
llvm-d28ed29d6bd9f0389092775406fff7e6205d4d5f.tar.bz2
[TTI][WebAssembly] Pairwise reduction expansion (#93948)
WebAssembly doesn't support horizontal operations nor does it have a way of expressing fast-math or reassoc flags, so runtimes are currently unable to use pairwise operations when generating code from the existing shuffle patterns. This patch allows the backend to select which, arbitary, shuffle pattern to be used per reduction intrinsic. The default behaviour is the same as the existing, which is by splitting the vector into a top and bottom half. The other pattern introduced is for a pairwise shuffle. WebAssembly enables pairwise reductions for int/fp add/sub.
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopUtils.cpp42
1 files changed, 30 insertions, 12 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index ff93035..4609376 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1077,7 +1077,9 @@ Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
// Helper to generate a log2 shuffle reduction.
Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
- unsigned Op, RecurKind RdxKind) {
+ unsigned Op,
+ TargetTransformInfo::ReductionShuffle RS,
+ RecurKind RdxKind) {
unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements();
// VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
// and vector ops, reducing the set of values being computed by half each
@@ -1091,18 +1093,10 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
// will never be relevant here. Note that it would be generally unsound to
// propagate these from an intrinsic call to the expansion anyways as we/
// change the order of operations.
- Value *TmpVec = Src;
- SmallVector<int, 32> ShuffleMask(VF);
- for (unsigned i = VF; i != 1; i >>= 1) {
- // Move the upper half of the vector to the lower half.
- for (unsigned j = 0; j != i / 2; ++j)
- ShuffleMask[j] = i / 2 + j;
-
- // Fill the rest of the mask with undef.
- std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1);
-
+ auto BuildShuffledOp = [&Builder, &Op,
+ &RdxKind](SmallVectorImpl<int> &ShuffleMask,
+ Value *&TmpVec) -> void {
Value *Shuf = Builder.CreateShuffleVector(TmpVec, ShuffleMask, "rdx.shuf");
-
if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
TmpVec = Builder.CreateBinOp((Instruction::BinaryOps)Op, TmpVec, Shuf,
"bin.rdx");
@@ -1111,6 +1105,30 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
"Invalid min/max");
TmpVec = createMinMaxOp(Builder, RdxKind, TmpVec, Shuf);
}
+ };
+
+ Value *TmpVec = Src;
+ if (TargetTransformInfo::ReductionShuffle::Pairwise == RS) {
+ SmallVector<int, 32> ShuffleMask(VF);
+ for (unsigned stride = 1; stride < VF; stride <<= 1) {
+ // Initialise the mask with undef.
+ std::fill(ShuffleMask.begin(), ShuffleMask.end(), -1);
+ for (unsigned j = 0; j < VF; j += stride << 1) {
+ ShuffleMask[j] = j + stride;
+ }
+ BuildShuffledOp(ShuffleMask, TmpVec);
+ }
+ } else {
+ SmallVector<int, 32> ShuffleMask(VF);
+ for (unsigned i = VF; i != 1; i >>= 1) {
+ // Move the upper half of the vector to the lower half.
+ for (unsigned j = 0; j != i / 2; ++j)
+ ShuffleMask[j] = i / 2 + j;
+
+ // Fill the rest of the mask with undef.
+ std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1);
+ BuildShuffledOp(ShuffleMask, TmpVec);
+ }
}
// The result is in the first element of the vector.
return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));