Skip to content

Conversation

@aphearin
Copy link
Contributor

This PR is still WIP. It brings in alternative versions of the functions calc_mgas_galpop and calc_mgas_singlegal with new versions based on jax.grad. Current, temporary names for these new functions are calc_mgas_galpop3 and calc_mgas_singlegal3, defined in the same module, diffstar.mgas_model.

Currently, the alternative method gives good agreement with the original method (except at very early times). I have done some preliminary benchmarking, and I think runtimes and memory consumption are comparable when computing gas histories for large galaxy populations, but this should be verified more carefully.

The plot below compares the old and new version for an example galaxy

dmgas_dt_comparison mgash_comparison

One thing to watch out for with the current implementation is the hard-coding done with the number of integration steps, which is currently done with a global variable:
N_INT_STEPS = 20
I have not tested this very thoroughly. But the performance and especially memory consumption depends sensitively upon this variable.

Closes #100.

@aphearin aphearin requested a review from alexalar October 13, 2025 23:34
@alexalar
Copy link
Collaborator

@aphearin I've reviewed the changes, and they all look good. The new functions make sense to me. The new implementation is more accurate because for a given t_obs it integrates with a fixed number of N_INT_STEPS, while the current implementation uses tarr as the t_table for integration, so it will become inaccurate for the first t_obs values. An advantage of this latter method is that there's only one integration of the SFH, since that same integration is used for all t_obs values, while the new implementation runs a separate integration step for each t_obs, so it should be more memory-intensive. If one is only interested in one t_obs at a time, then the new method is always better.

@aphearin
Copy link
Contributor Author

Thanks @alexalar for reviewing this and putting some thought into it!

An advantage of this latter method is that there's only one integration of the SFH, since that same integration is used for all t_obs values, while the new implementation runs a separate integration step for each t_obs, so it should be more memory-intensive.

Right, yes, this is exactly why I highlighted that I had not done memory testing yet. Without careful testing, I'd be worried about making this change since it could result in memory issues.

I can imagine a more efficient version of this new implementation in which we use lax.scan instead of jax.vmap for the transformation of t_obs-->tarr. But I'd like to do timing and memory footprint tests on the current code before trying to implement that.

How about we just leave this PR open for the time being? I can return to it when I have a little blue sky to work under.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Use jax.grad to calculate dmgas_dt rather than finite difference derivs?

3 participants