diff options
author | Sam Parker <sam.parker@arm.com> | 2024-07-17 09:21:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-17 09:21:52 +0100 |
commit | d28ed29d6bd9f0389092775406fff7e6205d4d5f (patch) | |
tree | be650c3abc641b4da02f1fcdb94c4a125881a691 /llvm/lib/Transforms/Utils/LoopUtils.cpp | |
parent | caaf8099efa87a7ebca8920971b7d7f719808591 (diff) | |
download | llvm-d28ed29d6bd9f0389092775406fff7e6205d4d5f.zip llvm-d28ed29d6bd9f0389092775406fff7e6205d4d5f.tar.gz llvm-d28ed29d6bd9f0389092775406fff7e6205d4d5f.tar.bz2 |
[TTI][WebAssembly] Pairwise reduction expansion (#93948)
WebAssembly doesn't support horizontal operations nor does it have a way
of expressing fast-math or reassoc flags, so runtimes are currently
unable to use pairwise operations when generating code from the existing
shuffle patterns.
This patch allows the backend to select which arbitrary shuffle pattern
is used per reduction intrinsic. The default behaviour is the same as
the existing one, which splits the vector into a top and bottom
half. The other pattern introduced is for a pairwise shuffle.
WebAssembly enables pairwise reductions for int/fp add/sub.
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 42 |
1 files changed, 30 insertions, 12 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index ff93035..4609376 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1077,7 +1077,9 @@ Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src, // Helper to generate a log2 shuffle reduction. Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, - unsigned Op, RecurKind RdxKind) { + unsigned Op, + TargetTransformInfo::ReductionShuffle RS, + RecurKind RdxKind) { unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements(); // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles // and vector ops, reducing the set of values being computed by half each @@ -1091,18 +1093,10 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, // will never be relevant here. Note that it would be generally unsound to // propagate these from an intrinsic call to the expansion anyways as we/ // change the order of operations. - Value *TmpVec = Src; - SmallVector<int, 32> ShuffleMask(VF); - for (unsigned i = VF; i != 1; i >>= 1) { - // Move the upper half of the vector to the lower half. - for (unsigned j = 0; j != i / 2; ++j) - ShuffleMask[j] = i / 2 + j; - - // Fill the rest of the mask with undef. 
- std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1); - + auto BuildShuffledOp = [&Builder, &Op, + &RdxKind](SmallVectorImpl<int> &ShuffleMask, + Value *&TmpVec) -> void { Value *Shuf = Builder.CreateShuffleVector(TmpVec, ShuffleMask, "rdx.shuf"); - if (Op != Instruction::ICmp && Op != Instruction::FCmp) { TmpVec = Builder.CreateBinOp((Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx"); @@ -1111,6 +1105,30 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, "Invalid min/max"); TmpVec = createMinMaxOp(Builder, RdxKind, TmpVec, Shuf); } + }; + + Value *TmpVec = Src; + if (TargetTransformInfo::ReductionShuffle::Pairwise == RS) { + SmallVector<int, 32> ShuffleMask(VF); + for (unsigned stride = 1; stride < VF; stride <<= 1) { + // Initialise the mask with undef. + std::fill(ShuffleMask.begin(), ShuffleMask.end(), -1); + for (unsigned j = 0; j < VF; j += stride << 1) { + ShuffleMask[j] = j + stride; + } + BuildShuffledOp(ShuffleMask, TmpVec); + } + } else { + SmallVector<int, 32> ShuffleMask(VF); + for (unsigned i = VF; i != 1; i >>= 1) { + // Move the upper half of the vector to the lower half. + for (unsigned j = 0; j != i / 2; ++j) + ShuffleMask[j] = i / 2 + j; + + // Fill the rest of the mask with undef. + std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1); + BuildShuffledOp(ShuffleMask, TmpVec); + } } // The result is in the first element of the vector. return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); |