diff options
author | Alexandre Oliva <aoliva@gcc.gnu.org> | 2019-08-09 09:20:58 +0000 |
---|---|---|
committer | Alexandre Oliva <aoliva@gcc.gnu.org> | 2019-08-09 09:20:58 +0000 |
commit | c787deb0124b667802d8519bc285894bb6d771d7 (patch) | |
tree | dcf471c17642d638b25a9502daa1265fac53201b | |
parent | 279dc7a3624ff68e9bb4f44293877250a8097c14 (diff) | |
download | gcc-c787deb0124b667802d8519bc285894bb6d771d7.zip gcc-c787deb0124b667802d8519bc285894bb6d771d7.tar.gz gcc-c787deb0124b667802d8519bc285894bb6d771d7.tar.bz2 |
skip Cholesky decomposition in is>>n_mv_dist
normal_mv_distribution maintains the variance-covariance matrix param
in Cholesky-decomposed form. Existing param_type constructors, when
taking a full or lower-triangle varcov matrix, perform Cholesky
decomposition to convert it to the internal representation. This
internal representation is visible both in the varcov() result, and in
the streamed-out representation of a normal_mv_distribution object.
The problem is that when that representation is streamed back in, the
read-back decomposed varcov matrix is used as a lower-triangle
non-decomposed varcov matrix, and it undergoes Cholesky decomposition
again. So, each cycle of stream-out/stream-in changes the varcov
matrix to its "square root", instead of restoring the original
params.
This patch includes Corentin's changes that introduce verification in
testsuite/ext/random/normal_mv_distribution/operators/serialize.cc and
other similar tests that the object read back in compares equal to the
written-out object: the modified tests pass only if (u == v).
This patch also fixes the error exposed by his change, introducing an
alternate private constructor for param_type, used only by operator>>.
for libstdc++-v3/ChangeLog
* include/ext/random
(normal_mv_distribution::param_type::param_type): New private
ctor taking a decomposed varcov matrix, for use by...
(operator>>): ... this, befriended.
* include/ext/random.tcc (operator>>): Use it.
(normal_mv_distribution::param_type::_M_init_lower): Adjust
member function name in exception message.
for libstdc++-v3/ChangeLog
from Corentin Gay <gay@adacore.com>
* testsuite/ext/random/beta_distribution/operators/serialize.cc,
testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc,
testsuite/ext/random/normal_mv_distribution/operators/serialize.cc,
testsuite/ext/random/triangular_distribution/operators/serialize.cc,
testsuite/ext/random/von_mises_distribution/operators/serialize.cc:
Add call to `VERIFY`.
From-SVN: r274233
8 files changed, 48 insertions, 3 deletions
diff --git a/libstdc++-v3/ChangeLog b/libstdc++-v3/ChangeLog index 29418eb..5c02cb3 100644 --- a/libstdc++-v3/ChangeLog +++ b/libstdc++-v3/ChangeLog @@ -1,3 +1,22 @@ +2019-08-09 Corentin Gay <gay@adacore.com> + + * testsuite/ext/random/beta_distribution/operators/serialize.cc, + testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc, + testsuite/ext/random/normal_mv_distribution/operators/serialize.cc, + testsuite/ext/random/triangular_distribution/operators/serialize.cc, + testsuite/ext/random/von_mises_distribution/operators/serialize.cc: + Add call to `VERIFY`. + +2019-08-09 Alexandre Oliva <oliva@adacore.com> + + * include/ext/random + (normal_mv_distribution::param_type::param_type): New private + ctor taking a decomposed varcov matrix, for use by... + (operator>>): ... this, befriended. + * include/ext/random.tcc (operator>>): Use it. + (normal_mv_distribution::param_type::_M_init_lower): Adjust + member function name in exception message. + 2019-08-08 Jonathan Wakely <jwakely@redhat.com> P0325R4 to_array from LFTS with updates diff --git a/libstdc++-v3/include/ext/random b/libstdc++-v3/include/ext/random index 41a2962..d5574e0 100644 --- a/libstdc++-v3/include/ext/random +++ b/libstdc++-v3/include/ext/random @@ -752,6 +752,21 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION _InputIterator2 __varbegin, _InputIterator2 __varend); + // param_type constructors apply Cholesky decomposition to the + // varcov matrix in _M_init_full and _M_init_lower, but the + // varcov matrix output ot a stream is already decomposed, so + // we need means to restore it as-is when reading it back in. + template<size_t _Dimen1, typename _RealType1, + typename _CharT, typename _Traits> + friend std::basic_istream<_CharT, _Traits>& + operator>>(std::basic_istream<_CharT, _Traits>& __is, + __gnu_cxx::normal_mv_distribution<_Dimen1, _RealType1>& + __x); + param_type(std::array<_RealType, _Dimen> const &__mean, + std::array<_RealType, _M_t_size> const &__varcov) + : _M_mean (__mean), _M_t (__varcov) + {} + std::array<_RealType, _Dimen> _M_mean; std::array<_RealType, _M_t_size> _M_t; }; diff --git a/libstdc++-v3/include/ext/random.tcc b/libstdc++-v3/include/ext/random.tcc index 31dc33a..a8a49a3 100644 --- a/libstdc++-v3/include/ext/random.tcc +++ b/libstdc++-v3/include/ext/random.tcc @@ -581,7 +581,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION __sum = *__varcovbegin++ - __sum; if (__builtin_expect(__sum <= _RealType(0), 0)) std::__throw_runtime_error(__N("normal_mv_distribution::" - "param_type::_M_init_full")); + "param_type::_M_init_lower")); *__w++ = std::sqrt(__sum); } } @@ -709,9 +709,11 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION __is >> __x._M_nd; + // The param_type temporary is built with a private constructor, + // to skip the Cholesky decomposition that would be performed + // otherwise. __x.param(typename normal_mv_distribution<_Dimen, _RealType>:: - param_type(__mean.begin(), __mean.end(), - __varcov.begin(), __varcov.end())); + param_type(__mean, __varcov)); __is.flags(__flags); return __is; diff --git a/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc index b054171..a4925fc 100644 --- a/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc @@ -23,6 +23,7 @@ #include <ext/random> #include <sstream> +#include <testsuite_hooks.h> void test01() @@ -35,6 +36,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int main() diff --git a/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc index 9c2cc46a..e9077b2 100644 --- a/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc @@ -38,6 +38,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int diff --git a/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc index 8d83f9e..f5fbc42a 100644 --- a/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc @@ -23,6 +23,7 @@ #include <ext/random> #include <sstream> +#include <testsuite_hooks.h> void test01() @@ -35,6 +36,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int main() diff --git a/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc index cf17fea..75e16cf 100644 --- a/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc @@ -23,6 +23,7 @@ #include <ext/random> #include <sstream> +#include <testsuite_hooks.h> void test01() @@ -35,6 +36,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int main() diff --git a/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc index f3d7912..b32a31d 100644 --- a/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc @@ -23,6 +23,7 @@ #include <ext/random> #include <sstream> +#include <testsuite_hooks.h> void test01() @@ -35,6 +36,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int main() |