For me, it boils down to three reasons. First, the JAX tracing process can get slow for more complex programs. By bringing Jaxprs into Rust, we can optimize the compile path further. This doesn't matter for small, simple programs, but for larger programs (especially ones with lots of unrolling) it could matter a lot. Second, I just prefer Rust for building production-ready software. The strong type system in particular lets you catch many errors at compile time. Python is moving in that direction, but IMO isn't there yet. Third, we want this software to integrate easily with your existing flight software (the control software for your drone/satellite), and a "systems" language like Rust makes that easier. Eventually, we would love to build out a suite of flight software in Rust that integrates easily with the simulation but is still flight-ready.
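To make the unrolling point concrete, here is a small JAX sketch (not Nox code; `unrolled` and `rolled` are hypothetical example functions). A Python `for` loop executes while JAX traces, so each iteration adds its own equations to the jaxpr. `lax.fori_loop` traces the body once and records a single loop primitive. In the first case the jaxpr, and the tracing and compile work that comes with it, grows with the unroll count; in the second it stays the same size.

```python
import jax
import jax.numpy as jnp
from jax import lax


def unrolled(x, n):
    # The Python loop runs during tracing, so every iteration
    # appends its own sin/mul equations to the jaxpr.
    for _ in range(n):
        x = jnp.sin(x) * 0.9
    return x


def rolled(x, n):
    # lax.fori_loop traces the body once and records a single loop primitive.
    return lax.fori_loop(0, n, lambda i, v: jnp.sin(v) * 0.9, x)


x = jnp.ones(4, dtype=jnp.float32)
counts = {}
for n in (10, 100):
    # Count the top-level equations in each traced program.
    u = len(jax.make_jaxpr(lambda v: unrolled(v, n))(x).jaxpr.eqns)
    r = len(jax.make_jaxpr(lambda v: rolled(v, n))(x).jaxpr.eqns)
    counts[n] = (u, r)
    print(f"n={n}: unrolled eqns={u}, fori_loop eqns={r}")
```

Exact equation counts can vary between JAX versions. The trend is what matters: the unrolled program grows roughly linearly with `n`, while the `fori_loop` version does not.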
A fun side effect of this architecture is that when we run your simulation, we aren't actually running your code (Python or Rust) in the loop. We compile the whole thing down to HLO at runtime, then run that HLO. So we could theoretically support lots of other languages.
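Here is a rough JAX-side illustration of the "no user code in the loop" idea. It assumes a recent JAX version and is not how Elodin itself is implemented; `step` and `trace_count` are hypothetical. The Python function is traced once, lowered to an XLA module (recent JAX emits StableHLO text here rather than classic HLO), and compiled. After that, only the compiled executable runs, and the `trace_count` side effect shows that the Python body executed just once.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how many times the Python body actually runs


def step(state):
    global trace_count
    trace_count += 1  # Python side effect: only happens while tracing
    return state * 0.5 + 1.0


x = jnp.ones(3, dtype=jnp.float32)

lowered = jax.jit(step).lower(x)  # trace Python -> jaxpr -> XLA module
hlo_text = lowered.as_text()      # textual IR handed to XLA
compiled = lowered.compile()      # XLA executable; no Python in the loop

for _ in range(100):
    x = compiled(x)               # runs only the compiled module

print(trace_count)  # prints 1
```

Because the Python body only matters at trace time, any front end that can produce the same IR could in principle drive the same executable, which is why other languages are possible.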
Exciting! I have been wanting a Rust port of JAX for a long time! (I mostly care about the ability to write numerical GPU code via XLA, not differentiation.)
Is there a crate in the works? I would start using it (and contributing if needed) instantly.
Yes, you can use Nox separately today; it lives at https://github.com/elodin-sys/elodin/tree/main/libs/nox and just hasn't been released on crates.io yet, with no promises about stability. I mentioned this in a sibling comment, but we will likely have to rename Nox to something else, since someone is squatting the name on crates.io.
Curious: why reimplement JAX in Rust (Nox)?