diff options
author | Peter Klausler <35819229+klausler@users.noreply.github.com> | 2024-06-24 10:46:30 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-24 10:46:30 -0700 |
commit | 514c1ec5477a48e4f639c0b15ab757832b67dd10 (patch) | |
tree | 7503ab8cfef3687c310a3d835e97417657f3089a | |
parent | eac925fb81f26342811ad1765e8f9919628e2254 (diff) | |
download | llvm-514c1ec5477a48e4f639c0b15ab757832b67dd10.zip llvm-514c1ec5477a48e4f639c0b15ab757832b67dd10.tar.gz llvm-514c1ec5477a48e4f639c0b15ab757832b67dd10.tar.bz2 |
[flang][runtime] Interoperable POINTER deallocation validation (#96100)
Extend the runtime validation of deallocated pointers so that it also
works when pointers are allocated &/or deallocated outside Fortran.
Previously, bogus runtime errors would be reported for pointers
allocated via CFI_allocate() and deallocated in Fortran, and
CFI_deallocate() did not check that it was deallocating a whole
contiguous pointer that was allocated as such.
-rw-r--r-- | flang/include/flang/Runtime/pointer.h | 5 | ||||
-rw-r--r-- | flang/runtime/ISO_Fortran_binding.cpp | 10 | ||||
-rw-r--r-- | flang/runtime/descriptor.cpp | 11 | ||||
-rw-r--r-- | flang/runtime/pointer.cpp | 67 |
4 files changed, 64 insertions, 29 deletions
diff --git a/flang/include/flang/Runtime/pointer.h b/flang/include/flang/Runtime/pointer.h index 6ceb70e..704144f 100644 --- a/flang/include/flang/Runtime/pointer.h +++ b/flang/include/flang/Runtime/pointer.h @@ -115,6 +115,11 @@ bool RTDECL(PointerIsAssociated)(const Descriptor &); bool RTDECL(PointerIsAssociatedWith)( const Descriptor &, const Descriptor *target); +// Fortran POINTERs are allocated with an extra validation word after their +// payloads in order to detect erroneous deallocations later. +RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t); +RT_API_ATTRS bool ValidatePointerPayload(const ISO::CFI_cdesc_t &); + } // extern "C" } // namespace Fortran::runtime #endif // FORTRAN_RUNTIME_POINTER_H_ diff --git a/flang/runtime/ISO_Fortran_binding.cpp b/flang/runtime/ISO_Fortran_binding.cpp index 99ba3aa..fe22026 100644 --- a/flang/runtime/ISO_Fortran_binding.cpp +++ b/flang/runtime/ISO_Fortran_binding.cpp @@ -13,6 +13,7 @@ #include "terminator.h" #include "flang/ISO_Fortran_binding_wrapper.h" #include "flang/Runtime/descriptor.h" +#include "flang/Runtime/pointer.h" #include "flang/Runtime/type-code.h" #include <cstdlib> @@ -75,7 +76,7 @@ RT_API_ATTRS int CFI_allocate(CFI_cdesc_t *descriptor, dim->sm = byteSize; byteSize *= extent; } - void *p{byteSize ? std::malloc(byteSize) : std::malloc(1)}; + void *p{runtime::AllocateValidatedPointerPayload(byteSize)}; if (!p && byteSize) { return CFI_ERROR_MEM_ALLOCATION; } @@ -91,8 +92,11 @@ RT_API_ATTRS int CFI_deallocate(CFI_cdesc_t *descriptor) { if (descriptor->version != CFI_VERSION) { return CFI_INVALID_DESCRIPTOR; } - if (descriptor->attribute != CFI_attribute_allocatable && - descriptor->attribute != CFI_attribute_pointer) { + if (descriptor->attribute == CFI_attribute_pointer) { + if (!runtime::ValidatePointerPayload(*descriptor)) { + return CFI_INVALID_DESCRIPTOR; + } + } else if (descriptor->attribute != CFI_attribute_allocatable) { // Non-interoperable object return CFI_INVALID_DESCRIPTOR; } diff --git a/flang/runtime/descriptor.cpp b/flang/runtime/descriptor.cpp index d8b51f1..9b04cb4 100644 --- a/flang/runtime/descriptor.cpp +++ b/flang/runtime/descriptor.cpp @@ -199,7 +199,16 @@ RT_API_ATTRS int Descriptor::Destroy( } } -RT_API_ATTRS int Descriptor::Deallocate() { return ISO::CFI_deallocate(&raw_); } +RT_API_ATTRS int Descriptor::Deallocate() { + ISO::CFI_cdesc_t &descriptor{raw()}; + if (!descriptor.base_addr) { + return CFI_ERROR_BASE_ADDR_NULL; + } else { + std::free(descriptor.base_addr); + descriptor.base_addr = nullptr; + return CFI_SUCCESS; + } +} RT_API_ATTRS bool Descriptor::DecrementSubscripts( SubscriptValue *subscript, const int *permutation) const { diff --git a/flang/runtime/pointer.cpp b/flang/runtime/pointer.cpp index 08a1223..aeed879 100644 --- a/flang/runtime/pointer.cpp +++ b/flang/runtime/pointer.cpp @@ -124,6 +124,23 @@ void RTDEF(PointerAssociateRemapping)(Descriptor &pointer, } } +RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t byteSize) { + // Add space for a footer to validate during deallocation. + constexpr std::size_t align{sizeof(std::uintptr_t)}; + byteSize = ((byteSize / align) + 1) * align; + std::size_t total{byteSize + sizeof(std::uintptr_t)}; + void *p{std::malloc(total)}; + if (p) { + // Fill the footer word with the XOR of the ones' complement of + // the base address, which is a value that would be highly unlikely + // to appear accidentally at the right spot. + std::uintptr_t *footer{ + reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)}; + *footer = ~reinterpret_cast<std::uintptr_t>(p); + } + return p; +} + int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat, const Descriptor *errMsg, const char *sourceFile, int sourceLine) { Terminator terminator{sourceFile, sourceLine}; @@ -137,22 +154,12 @@ int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat, elementBytes = pointer.raw().elem_len = 0; } std::size_t byteSize{pointer.Elements() * elementBytes}; - // Add space for a footer to validate during DEALLOCATE. - constexpr std::size_t align{sizeof(std::uintptr_t)}; - byteSize = ((byteSize + align - 1) / align) * align; - std::size_t total{byteSize + sizeof(std::uintptr_t)}; - void *p{std::malloc(total)}; + void *p{AllocateValidatedPointerPayload(byteSize)}; if (!p) { return ReturnError(terminator, CFI_ERROR_MEM_ALLOCATION, errMsg, hasStat); } pointer.set_base_addr(p); pointer.SetByteStrides(); - // Fill the footer word with the XOR of the ones' complement of - // the base address, which is a value that would be highly unlikely - // to appear accidentally at the right spot. - std::uintptr_t *footer{ - reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)}; - *footer = ~reinterpret_cast<std::uintptr_t>(p); int stat{StatOk}; if (const DescriptorAddendum * addendum{pointer.Addendum()}) { if (const auto *derived{addendum->derivedType()}) { @@ -176,6 +183,27 @@ int RTDEF(PointerAllocateSource)(Descriptor &pointer, const Descriptor &source, return stat; } +static RT_API_ATTRS std::size_t GetByteSize( + const ISO::CFI_cdesc_t &descriptor) { + std::size_t rank{descriptor.rank}; + const ISO::CFI_dim_t *dim{descriptor.dim}; + std::size_t byteSize{descriptor.elem_len}; + for (std::size_t j{0}; j < rank; ++j) { + byteSize *= dim[j].extent; + } + return byteSize; +} + +bool RT_API_ATTRS ValidatePointerPayload(const ISO::CFI_cdesc_t &desc) { + std::size_t byteSize{GetByteSize(desc)}; + constexpr std::size_t align{sizeof(std::uintptr_t)}; + byteSize = ((byteSize / align) + 1) * align; + const void *p{desc.base_addr}; + const std::uintptr_t *footer{reinterpret_cast<const std::uintptr_t *>( + static_cast<const char *>(p) + byteSize)}; + return *footer == ~reinterpret_cast<std::uintptr_t>(p); +} + int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat, const Descriptor *errMsg, const char *sourceFile, int sourceLine) { Terminator terminator{sourceFile, sourceLine}; @@ -185,20 +213,9 @@ int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat, if (!pointer.IsAllocated()) { return ReturnError(terminator, StatBaseNull, errMsg, hasStat); } - if (executionEnvironment.checkPointerDeallocation) { - // Validate the footer. This should fail if the pointer doesn't - // span the entire object, or the object was not allocated as a - // pointer. - std::size_t byteSize{pointer.Elements() * pointer.ElementBytes()}; - constexpr std::size_t align{sizeof(std::uintptr_t)}; - byteSize = ((byteSize + align - 1) / align) * align; - void *p{pointer.raw().base_addr}; - std::uintptr_t *footer{ - reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)}; - if (*footer != ~reinterpret_cast<std::uintptr_t>(p)) { - return ReturnError( - terminator, StatBadPointerDeallocation, errMsg, hasStat); - } + if (executionEnvironment.checkPointerDeallocation && + !ValidatePointerPayload(pointer.raw())) { + return ReturnError(terminator, StatBadPointerDeallocation, errMsg, hasStat); } return ReturnError(terminator, pointer.Destroy(/*finalize=*/true, /*destroyPointers=*/true, &terminator), |