JAX Talk: Diffrax https://www.youtube.com/watch?v=Jy5Jw8hNiAQ
`vmap`-able differential equation solving is really cool.
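To make that concrete, here is a minimal sketch of what `vmap`-able solving looks like. It does not use Diffrax's API: the fixed-step RK4 loop, the `decay` right-hand side, and the names `rk4_solve`, `y0s`, and `ks` are all assumptions made up for illustration. The point is only that a solver written in plain JAX can be batched over initial conditions and parameters with a single `jax.vmap` call.

```python
import jax
import jax.numpy as jnp


def decay(t, y, k):
    # Right-hand side of the toy ODE dy/dt = -k * y (illustrative choice).
    return -k * y


def rk4_solve(y0, k, t1=1.0, num_steps=100):
    # Integrate from t = 0 to t = t1 with classical fixed-step RK4.
    dt = t1 / num_steps

    def step(i, y):
        t = i * dt
        k1 = decay(t, y, k)
        k2 = decay(t + dt / 2, y + dt / 2 * k1, k)
        k3 = decay(t + dt / 2, y + dt / 2 * k2, k)
        k4 = decay(t + dt, y + dt * k3, k)
        return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    return jax.lax.fori_loop(0, num_steps, step, y0)


# Three independent problems: different initial values and decay rates.
y0s = jnp.array([1.0, 2.0, 3.0])
ks = jnp.array([0.5, 1.0, 2.0])

# One vmap call solves all three at once; no Python loop over problems.
batched = jax.vmap(rk4_solve)(y0s, ks)

# Closed-form solution y(1) = y0 * exp(-k), for comparison.
exact = y0s * jnp.exp(-ks)
```

A real library adds adaptive step sizes, error control, and many more solvers, but the batching idea is the same: the solve is just a function, so `vmap`, `jit`, and `grad` compose with it.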
[1]: https://dynamicfieldtheory.org/
[2]: https://github.com/patrick-kidger/equinox
Kidger's thesis is wonderful https://arxiv.org/abs/2202.02435
jax is fun but not as effective as i'd like on CPU
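A Neural ODE reframes this: instead of focusing on a fixed stack of layers, focus on how the state changes. A small network defines the rate of change of the hidden state, dh/dt = f(h, t, θ), and the forward pass is a path from the input state h(0) to the output state h(T). An ODE solver computes that path one step at a time, each step producing the next state, for N steps (or however many an adaptive solver decides it needs). Training then adjusts the weights θ until the solved end state matches the training data, which gives you the trained network.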
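A rough sketch of that idea in plain JAX, under several assumptions: the one-layer `vector_field`, the fixed-step Euler solver, the toy `h0`/`target` pair, and the learning rate are all hypothetical choices for illustration. It also differentiates straight through the solver steps rather than using the adjoint method, and it does not use Diffrax.

```python
import jax
import jax.numpy as jnp


def vector_field(params, h):
    # Hypothetical one-layer network that outputs dh/dt for state h.
    W, b = params
    return jnp.tanh(W @ h + b)


def forward(params, h0, num_steps=20, t1=1.0):
    # Forward pass = solve the ODE from t = 0 to t = t1 with Euler steps.
    dt = t1 / num_steps

    def euler_step(h, _):
        return h + dt * vector_field(params, h), None

    h1, _ = jax.lax.scan(euler_step, h0, None, length=num_steps)
    return h1


def loss(params, h0, target):
    # Squared error between the solved end state and the target.
    return jnp.sum((forward(params, h0) - target) ** 2)


key = jax.random.PRNGKey(0)
params = (0.1 * jax.random.normal(key, (2, 2)), jnp.zeros(2))
h0 = jnp.array([1.0, 0.0])       # toy input state
target = jnp.array([1.5, 0.5])   # toy desired end state

initial_loss = loss(params, h0, target)

# Plain gradient descent, backpropagating through every solver step.
grad_fn = jax.jit(jax.grad(loss))
for _ in range(100):
    grads = grad_fn(params, h0, target)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

final_loss = loss(params, h0, target)
```

Swapping the Euler loop for an adaptive solver changes how many steps are taken, not the structure: the network still only describes the derivative, and the solver turns it into an output.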