diff options
author | Yingwei Zheng <dtcxzyw2333@gmail.com> | 2024-03-05 22:34:04 +0800 |
---|---|---|
committer | Tom Stellard <tstellar@redhat.com> | 2024-03-12 22:05:53 -0700 |
commit | 3f8711fc5e01685f0a751ef296d16cf9a1f4fd4d (patch) | |
tree | 273a38764bdc88051b3ba85e9a9a5e7a27a2731b | |
parent | 9b9aee16d4dcf1b4af49988ebd7918fa4ce77e44 (diff) | |
download | llvm-3f8711fc5e01685f0a751ef296d16cf9a1f4fd4d.zip llvm-3f8711fc5e01685f0a751ef296d16cf9a1f4fd4d.tar.gz llvm-3f8711fc5e01685f0a751ef296d16cf9a1f4fd4d.tar.bz2 |
[InstCombine] Fix miscompilation in PR83947 (#83993)
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407
Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.
Fixes #83947.
-rw-r--r-- | llvm/include/llvm/Analysis/VectorUtils.h | 5 | ||||
-rw-r--r-- | llvm/lib/Analysis/VectorUtils.cpp | 25 | ||||
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 13 | ||||
-rw-r--r-- | llvm/test/Transforms/InstCombine/masked_intrinsics.ll | 6 | ||||
-rw-r--r-- | llvm/test/Transforms/InstCombine/pr83947.ll | 67 |
5 files changed, 110 insertions, 6 deletions
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h index 7a92e62..c6eb66c 100644 --- a/llvm/include/llvm/Analysis/VectorUtils.h +++ b/llvm/include/llvm/Analysis/VectorUtils.h @@ -406,6 +406,11 @@ bool maskIsAllZeroOrUndef(Value *Mask); /// lanes can be assumed active. bool maskIsAllOneOrUndef(Value *Mask); +/// Given a mask vector of i1, Return true if any of the elements of this +/// predicate mask are known to be true or undef. That is, return true if at +/// least one lane can be assumed active. +bool maskContainsAllOneOrUndef(Value *Mask); + /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y) /// for each lane which may be active. APInt possiblyDemandedEltsInMask(Value *Mask); diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index 73facc7..bf7bc0b 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -1012,6 +1012,31 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) { return true; } +bool llvm::maskContainsAllOneOrUndef(Value *Mask) { + assert(isa<VectorType>(Mask->getType()) && + isa<IntegerType>(Mask->getType()->getScalarType()) && + cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() == + 1 && + "Mask must be a vector of i1"); + + auto *ConstMask = dyn_cast<Constant>(Mask); + if (!ConstMask) + return false; + if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask)) + return true; + if (isa<ScalableVectorType>(ConstMask->getType())) + return false; + for (unsigned + I = 0, + E = cast<FixedVectorType>(ConstMask->getType())->getNumElements(); + I != E; ++I) { + if (auto *MaskElt = ConstMask->getAggregateElement(I)) + if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt)) + return true; + } + return false; +} + /// TODO: This is a lot like known bits, but for /// vectors. Is there something we can common this with? APInt llvm::possiblyDemandedEltsInMask(Value *Mask) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index a647be2..bc43edb 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -412,11 +412,14 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) { // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) { - Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); - StoreInst *S = - new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment); - S->copyMetadata(II); - return S; + if (maskContainsAllOneOrUndef(ConstMask)) { + Align Alignment = + cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); + StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, + Alignment); + S->copyMetadata(II); + return S; + } } // scatter(vector, splat(ptr), splat(true)) -> store extract(vector, // lastlane), ptr diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll index 2704905..c87c119 100644 --- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll +++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll @@ -292,7 +292,11 @@ entry: define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(ptr %dst, i16 %val) { ; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask( ; CHECK-NEXT: entry: -; CHECK-NEXT: store i16 [[VAL:%.*]], ptr [[DST:%.*]], align 2 +; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[DST:%.*]], i64 0 +; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x ptr> [[BROADCAST_SPLATINSERT]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer +; CHECK-NEXT: [[BROADCAST_VALUE:%.*]] = insertelement <vscale x 4 x i16> poison, i16 [[VAL:%.*]], i64 0 +; CHECK-NEXT: [[BROADCAST_SPLATVALUE:%.*]] = shufflevector <vscale x 4 x i16> [[BROADCAST_VALUE]], <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer +; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i16.nxv4p0(<vscale x 4 x i16> [[BROADCAST_SPLATVALUE]], <vscale x 4 x ptr> [[BROADCAST_SPLAT]], i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer, i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer)) ; CHECK-NEXT: ret void ; entry: diff --git a/llvm/test/Transforms/InstCombine/pr83947.ll b/llvm/test/Transforms/InstCombine/pr83947.ll new file mode 100644 index 0000000..c1d601f --- /dev/null +++ b/llvm/test/Transforms/InstCombine/pr83947.ll @@ -0,0 +1,67 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 +; RUN: opt -S -passes=instcombine < %s | FileCheck %s + +@c = global i32 0, align 4 +@b = global i32 0, align 4 + +define void @masked_scatter1() { +; CHECK-LABEL: define void @masked_scatter1() { +; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @c, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer), i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer)) +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> splat (ptr @c), i32 4, <vscale x 4 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c))) + ret void +} + +define void @masked_scatter2() { +; CHECK-LABEL: define void @masked_scatter2() { +; CHECK-NEXT: store i32 0, ptr @c, align 4 +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 true)) + ret void +} + +define void @masked_scatter3() { +; CHECK-LABEL: define void @masked_scatter3() { +; CHECK-NEXT: store i32 0, ptr @c, align 4 +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef) + ret void +} + +define void @masked_scatter4() { +; CHECK-LABEL: define void @masked_scatter4() { +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 false)) + ret void +} + +define void @masked_scatter5() { +; CHECK-LABEL: define void @masked_scatter5() { +; CHECK-NEXT: store i32 0, ptr @c, align 4 +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 true, i1 false>) + ret void +} + +define void @masked_scatter6() { +; CHECK-LABEL: define void @masked_scatter6() { +; CHECK-NEXT: store i32 0, ptr @c, align 4 +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 undef, i1 false>) + ret void +} + +define void @masked_scatter7() { +; CHECK-LABEL: define void @masked_scatter7() { +; CHECK-NEXT: call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> <ptr @c, ptr @c>, i32 4, <2 x i1> <i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c)>) +; CHECK-NEXT: ret void +; + call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c))) + ret void +} |