aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYingwei Zheng <dtcxzyw2333@gmail.com>2024-03-05 22:34:04 +0800
committerTom Stellard <tstellar@redhat.com>2024-03-12 22:05:53 -0700
commit3f8711fc5e01685f0a751ef296d16cf9a1f4fd4d (patch)
tree273a38764bdc88051b3ba85e9a9a5e7a27a2731b
parent9b9aee16d4dcf1b4af49988ebd7918fa4ce77e44 (diff)
downloadllvm-3f8711fc5e01685f0a751ef296d16cf9a1f4fd4d.zip
llvm-3f8711fc5e01685f0a751ef296d16cf9a1f4fd4d.tar.gz
llvm-3f8711fc5e01685f0a751ef296d16cf9a1f4fd4d.tar.bz2
[InstCombine] Fix miscompilation in PR83947 (#83993)
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407 Comment from @topperc: > This transforms assumes the mask is a non-zero splat. We only know its a splat and not provably all 0s. The mask is a constexpr that includes the address of the global variable. We can't resolve the constant expression to an exact value. Fixes #83947.
-rw-r--r--llvm/include/llvm/Analysis/VectorUtils.h5
-rw-r--r--llvm/lib/Analysis/VectorUtils.cpp25
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp13
-rw-r--r--llvm/test/Transforms/InstCombine/masked_intrinsics.ll6
-rw-r--r--llvm/test/Transforms/InstCombine/pr83947.ll67
5 files changed, 110 insertions, 6 deletions
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7a92e62..c6eb66c 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -406,6 +406,11 @@ bool maskIsAllZeroOrUndef(Value *Mask);
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);
+/// Given a mask vector of i1, Return true if any of the elements of this
+/// predicate mask are known to be true or undef. That is, return true if at
+/// least one lane can be assumed active.
+bool maskContainsAllOneOrUndef(Value *Mask);
+
/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 73facc7..bf7bc0b 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1012,6 +1012,31 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) {
return true;
}
+bool llvm::maskContainsAllOneOrUndef(Value *Mask) {
+ assert(isa<VectorType>(Mask->getType()) &&
+ isa<IntegerType>(Mask->getType()->getScalarType()) &&
+ cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
+ 1 &&
+ "Mask must be a vector of i1");
+
+ auto *ConstMask = dyn_cast<Constant>(Mask);
+ if (!ConstMask)
+ return false;
+ if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
+ return true;
+ if (isa<ScalableVectorType>(ConstMask->getType()))
+ return false;
+ for (unsigned
+ I = 0,
+ E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
+ I != E; ++I) {
+ if (auto *MaskElt = ConstMask->getAggregateElement(I))
+ if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
+ return true;
+ }
+ return false;
+}
+
/// TODO: This is a lot like known bits, but for
/// vectors. Is there something we can common this with?
APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index a647be2..bc43edb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -412,11 +412,14 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
- Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
- StoreInst *S =
- new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
- S->copyMetadata(II);
- return S;
+ if (maskContainsAllOneOrUndef(ConstMask)) {
+ Align Alignment =
+ cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+ StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false,
+ Alignment);
+ S->copyMetadata(II);
+ return S;
+ }
}
// scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
// lastlane), ptr
diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index 2704905..c87c119 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -292,7 +292,11 @@ entry:
define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(ptr %dst, i16 %val) {
; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
-; CHECK-NEXT: store i16 [[VAL:%.*]], ptr [[DST:%.*]], align 2
+; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[DST:%.*]], i64 0
+; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x ptr> [[BROADCAST_SPLATINSERT]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT: [[BROADCAST_VALUE:%.*]] = insertelement <vscale x 4 x i16> poison, i16 [[VAL:%.*]], i64 0
+; CHECK-NEXT: [[BROADCAST_SPLATVALUE:%.*]] = shufflevector <vscale x 4 x i16> [[BROADCAST_VALUE]], <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i16.nxv4p0(<vscale x 4 x i16> [[BROADCAST_SPLATVALUE]], <vscale x 4 x ptr> [[BROADCAST_SPLAT]], i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer, i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
; CHECK-NEXT: ret void
;
entry:
diff --git a/llvm/test/Transforms/InstCombine/pr83947.ll b/llvm/test/Transforms/InstCombine/pr83947.ll
new file mode 100644
index 0000000..c1d601f
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pr83947.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+@c = global i32 0, align 4
+@b = global i32 0, align 4
+
+define void @masked_scatter1() {
+; CHECK-LABEL: define void @masked_scatter1() {
+; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @c, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer), i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT: ret void
+;
+ call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> splat (ptr @c), i32 4, <vscale x 4 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+ ret void
+}
+
+define void @masked_scatter2() {
+; CHECK-LABEL: define void @masked_scatter2() {
+; CHECK-NEXT: store i32 0, ptr @c, align 4
+; CHECK-NEXT: ret void
+;
+ call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 true))
+ ret void
+}
+
+define void @masked_scatter3() {
+; CHECK-LABEL: define void @masked_scatter3() {
+; CHECK-NEXT: store i32 0, ptr @c, align 4
+; CHECK-NEXT: ret void
+;
+ call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef)
+ ret void
+}
+
+define void @masked_scatter4() {
+; CHECK-LABEL: define void @masked_scatter4() {
+; CHECK-NEXT: ret void
+;
+ call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 false))
+ ret void
+}
+
+define void @masked_scatter5() {
+; CHECK-LABEL: define void @masked_scatter5() {
+; CHECK-NEXT: store i32 0, ptr @c, align 4
+; CHECK-NEXT: ret void
+;
+ call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @masked_scatter6() {
+; CHECK-LABEL: define void @masked_scatter6() {
+; CHECK-NEXT: store i32 0, ptr @c, align 4
+; CHECK-NEXT: ret void
+;
+ call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 undef, i1 false>)
+ ret void
+}
+
+define void @masked_scatter7() {
+; CHECK-LABEL: define void @masked_scatter7() {
+; CHECK-NEXT: call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> <ptr @c, ptr @c>, i32 4, <2 x i1> <i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c)>)
+; CHECK-NEXT: ret void
+;
+ call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+ ret void
+}