diff options
author | Pete Steinfeld <47540744+psteinfeld@users.noreply.github.com> | 2024-06-27 14:54:02 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-27 14:54:02 -0700 |
commit | e55aa027f813679ca63c9b803690ce792a3d7b28 (patch) | |
tree | 881f6743391e455626b6c859b0657e355fc5f094 /flang | |
parent | 5b363483cf2461617fbb2449491c9914811c8d53 (diff) | |
download | llvm-e55aa027f813679ca63c9b803690ce792a3d7b28.zip llvm-e55aa027f813679ca63c9b803690ce792a3d7b28.tar.gz llvm-e55aa027f813679ca63c9b803690ce792a3d7b28.tar.bz2 |
[flang] Fix runtime error messages for the MATMUL intrinsic (#96928)
There are three forms of MATMUL -- where the first argument is a rank 1
array, where the second argument is a rank 1 array, and where both
arguments are rank 2 arrays. There's code in the runtime that detects
when the array shapes are incorrect. But the code that emits an error
message assumes that both arguments are rank 2 arrays.
This change contains code for the other two cases.
Diffstat (limited to 'flang')
-rw-r--r-- | flang/runtime/matmul.cpp | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp index 543284c..8f9b50a 100644 --- a/flang/runtime/matmul.cpp +++ b/flang/runtime/matmul.cpp @@ -288,11 +288,25 @@ static inline RT_API_ATTRS void DoMatmul( } SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; if (n != y.GetDimension(0).Extent()) { - terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", - static_cast<std::intmax_t>(x.GetDimension(0).Extent()), - static_cast<std::intmax_t>(n), - static_cast<std::intmax_t>(y.GetDimension(0).Extent()), - static_cast<std::intmax_t>(y.GetDimension(1).Extent())); + // At this point, we know that there's a shape error. There are three + // possibilities, x is rank 1, y is rank 1, or both are rank 2. + if (xRank == 1) { + terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)", + static_cast<std::intmax_t>(n), + static_cast<std::intmax_t>(y.GetDimension(0).Extent()), + static_cast<std::intmax_t>(y.GetDimension(1).Extent())); + } else if (yRank == 1) { + terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)", + static_cast<std::intmax_t>(x.GetDimension(0).Extent()), + static_cast<std::intmax_t>(n), + static_cast<std::intmax_t>(y.GetDimension(0).Extent())); + } else { + terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", + static_cast<std::intmax_t>(x.GetDimension(0).Extent()), + static_cast<std::intmax_t>(n), + static_cast<std::intmax_t>(y.GetDimension(0).Extent()), + static_cast<std::intmax_t>(y.GetDimension(1).Extent())); + } } using WriteResult = CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT, |