aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/NVPTX
diff options
context:
space:
mode:
authorPetr <piter.zh@gmail.com>2024-02-12 13:50:00 +0100
committerGitHub <noreply@github.com>2024-02-12 13:50:00 +0100
commit45260bf23b802047ab4fd888b8bf2b32e4c5eb69 (patch)
tree91c888699e170bcb68e28f9ba5922e5c2d540f76 /llvm/lib/Target/NVPTX
parent0940f9083e68bda78bcbb323c2968a4294092e21 (diff)
downloadllvm-45260bf23b802047ab4fd888b8bf2b32e4c5eb69.zip
llvm-45260bf23b802047ab4fd888b8bf2b32e4c5eb69.tar.gz
llvm-45260bf23b802047ab4fd888b8bf2b32e4c5eb69.tar.bz2
Fix use after free error in NVVMReflect (#81471)
I have a Triton kernel, which triggered a heap-use-after-free error in LLVM. The problem was that the same instruction may be added to the `ToSimplify` array multiple times. If this duplicate instruction is trivially dead, it gets deleted on the first pass. Then, on the second pass, the freed instruction is passed. To fix this, I'm adding the instructions to the `ToRemove` array and filter it out for duplicates to avoid possible double frees.
Diffstat (limited to 'llvm/lib/Target/NVPTX')
-rw-r--r--llvm/lib/Target/NVPTX/NVVMReflect.cpp18
1 files changed, 13 insertions, 5 deletions
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 64fedf3..29c95e4 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -39,6 +39,7 @@
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
+#include <algorithm>
#include <sstream>
#include <string>
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
@@ -185,9 +186,6 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
ToRemove.push_back(Call);
}
- for (Instruction *I : ToRemove)
- I->eraseFromParent();
-
// The code guarded by __nvvm_reflect may be invalid for the target machine.
// Traverse the use-def chain, continually simplifying constant expressions
// until we find a terminator that we can then remove.
@@ -200,13 +198,23 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
ToSimplify.push_back(I);
I->replaceAllUsesWith(C);
- if (isInstructionTriviallyDead(I))
- I->eraseFromParent();
+ if (isInstructionTriviallyDead(I)) {
+ ToRemove.push_back(I);
+ }
} else if (I->isTerminator()) {
ConstantFoldTerminator(I->getParent());
}
}
+ // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
+ // array. Filter out the duplicates before starting to erase from parent.
+ std::sort(ToRemove.begin(), ToRemove.end());
+ auto NewLastIter = std::unique(ToRemove.begin(), ToRemove.end());
+ ToRemove.erase(NewLastIter, ToRemove.end());
+
+ for (Instruction *I : ToRemove)
+ I->eraseFromParent();
+
return ToRemove.size() > 0;
}