diff --git a/expui/BiorthBasis.H b/expui/BiorthBasis.H index a4af27c1a..9b8c4e744 100644 --- a/expui/BiorthBasis.H +++ b/expui/BiorthBasis.H @@ -3,6 +3,7 @@ #include #include +#include #include #include // For 3d rectangular grids @@ -1033,9 +1034,9 @@ namespace BasisClasses std::tuple> - IntegrateOrbits (double tinit, double tfinal, double h, + IntegrateOrbits (double tinit, double tfinal, Eigen::MatrixXd ps, std::vector bfe, - AccelFunctor F, int nout=std::numeric_limits::max()); + AccelFunctor F, std::optional h=std::nullopt, std::optional nout=std::nullopt); using BiorthBasisPtr = std::shared_ptr; } diff --git a/expui/BiorthBasis.cc b/expui/BiorthBasis.cc index f69bb0706..442a71e2b 100644 --- a/expui/BiorthBasis.cc +++ b/expui/BiorthBasis.cc @@ -3147,9 +3147,9 @@ namespace BasisClasses std::tuple> IntegrateOrbits - (double tinit, double tfinal, double h, + (double tinit, double tfinal, Eigen::MatrixXd ps, std::vector bfe, AccelFunctor F, - int nout) + std::optional h, std::optional nout) { int rows = ps.rows(); int cols = ps.cols(); @@ -3169,37 +3169,70 @@ namespace BasisClasses // Sanity check // - if ( (tfinal - tinit)/h > + if (h.has_value()){ + if (*h <= 0.){ + std::ostringstream sout; + sout << "BasicFactor::IntegrateOrbits: unreasonal input"; + throw std::runtime_error(sout.str()); + } + } + if (nout.has_value()){ + if (*nout < 2){ + std::ostringstream sout; + sout << "BasicFactor::IntegrateOrbits: unreasonal input"; + throw std::runtime_error(sout.str()); + } + } + if ( (tfinal - tinit)/ *h > static_cast(std::numeric_limits::max()) ) { - std::cout << "BasicFactor::IntegrateOrbits: step size is too small or " + std::ostringstream sout; + sout << "BasicFactor::IntegrateOrbits: step size is too small or " << "time interval is too large.\n"; - // Return empty data - // - return {Eigen::VectorXd(), Eigen::Tensor()}; + throw std::runtime_error(sout.str()); } + if (tinit < 0 || tfinal < tinit){ + std::ostringstream sout; + sout << "BasicFactor::IntegrateOrbits: unreasonal input"; + throw std::runtime_error(sout.str()); + } - // Number of steps - // - int numT = floor( (tfinal - tinit)/h ); - - // Compute output step - // - nout = std::min(numT, nout); - double H = (tfinal - tinit)/nout; + int numT = 0; + double H = 0.0; + if (h.has_value() && nout.has_value()){ + numT = floor( (tfinal - tinit)/ *h ) + 1; + if (std::abs(numT - *nout) > 0){ + std::ostringstream sout; + sout << "BasicFactor::IntegrateOrbits: inconsistent step number/size input"; + throw std::runtime_error(sout.str()); + } else { + H = *h; + } + } else if (h.has_value()) { + numT = floor( (tfinal - tinit)/ *h ) + 1; + H = *h; + } else if (nout.has_value()) { + numT = *nout; + H = (tfinal - tinit)/ (*nout - 1); + } else { + std::ostringstream sout; + sout << "BasicFactor::IntegrateOrbits: no step number/size input"; + throw std::runtime_error(sout.str()); + } + // Use numT and H for calculation // Return data // Eigen::Tensor ret; try { - ret.resize(rows, 6, nout); + ret.resize(rows, 6, numT); } catch (const std::bad_alloc& e) { std::cout << "BasicFactor::IntegrateOrbits: memory allocation failed: " << e.what() << std::endl << "Your requested number of orbits and time steps requires " - << floor(8.0*rows*6*nout/1e9)+1 << " GB free memory" + << floor(8.0*rows*6*numT/1e9)+1 << " GB free memory" << std::endl; // Return empty data @@ -3209,7 +3242,7 @@ namespace BasisClasses // Time array // - Eigen::VectorXd times(nout); + Eigen::VectorXd times(numT); // Do the work // @@ -3218,19 +3251,14 @@ namespace BasisClasses for (int k=0; k<6; k++) ret(n, k, 0) = ps(n, k); double tnow = tinit; - for (int s=1, cnt=1; s= H*cnt-h*1.0e-8) { - times(cnt) = tnow; - for (int n=0; n= H*cnt-H*1.0e-8) { + times(cnt) = tnow; + for (int n=0; n bfe, - BasisClasses::AccelFunc& func, int stride) + BasisClasses::AccelFunc& func, std::optional h, std::optional nout) { Eigen::VectorXd T; Eigen::Tensor O; @@ -2012,9 +2012,9 @@ void BasisFactoryClasses(py::module &m) AccelFunctor F = [&func](double t, Eigen::MatrixXd& ps, Eigen::MatrixXd& accel, BasisCoef mod)->Eigen::MatrixXd& { return func.F(t, ps, accel, mod);}; std::tie(T, O) = - BasisClasses::IntegrateOrbits(tinit, tfinal, h, ps, bfe, F, stride); + BasisClasses::IntegrateOrbits(tinit, tfinal, ps, bfe, F, h, nout); - py::array_t ret = make_ndarray3(O); + py::array_t ret = make_ndarray(O); return std::tuple>(T, ret); }, R"( @@ -2049,7 +2049,7 @@ void BasisFactoryClasses(py::module &m) tuple(numpy.array, numpy.ndarray) time and phase-space arrays )", - py::arg("tinit"), py::arg("tfinal"), py::arg("h"), + py::arg("tinit"), py::arg("tfinal"), py::arg("h")=std::nullopt, py::arg("ps"), py::arg("basiscoef"), py::arg("func"), - py::arg("nout")=std::numeric_limits::max()); + py::arg("nout")=std::nullopt); }