Array size incompatible with Jaxified IDAKLU solver

I am following the "IDAKLU-JAX interface" tutorial in the PyBaMM v25.4.3.dev11+g34186fe18 manual, which exposes the IDAKLU solver to JAX.
The tutorial code runs fine with output variables that return 1D arrays (Time, Voltage, Loss of lithium inventory), but variables that return 2D arrays (Cell temperature) cause the code to break. Has anyone run into this before, and does anyone have suggestions for reducing the dimensionality of the cell temperature array?
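To make the shape mismatch concrete, here is a minimal sketch in plain JAX with synthetic data. It does not call PyBaMM. It assumes, hypothetically, that a scalar output such as Voltage has shape `(n_t,)` and that a spatially resolved output such as cell temperature has shape `(n_x, n_t)`. The layout actually returned by the IDAKLU-JAX interface may differ. The helper `reduce_to_1d` is hypothetical and takes an unweighted mean over the spatial axis, which only matches a true spatial average on a uniform mesh.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: n_t time points, n_x spatial nodes (not taken from PyBaMM).
n_t, n_x = 5, 3
t = jnp.linspace(0.0, 3600.0, n_t)

# Synthetic scalar output (like Voltage): one value per time point, shape (n_t,).
voltage = 4.2 - 1e-4 * t

# Synthetic spatially resolved output (like cell temperature): shape (n_x, n_t).
# Every node starts at 298.15 K and warms at a node-dependent rate.
temperature = 298.15 + jnp.outer(jnp.arange(1, n_x + 1), t) * 1e-4


def reduce_to_1d(y):
    """Hypothetical helper: collapse (n_x, n_t) to (n_t,) by an unweighted
    mean over the spatial axis (assumes a uniform mesh)."""
    return jnp.mean(y, axis=0)


temperature_1d = reduce_to_1d(temperature)

# Once every variable has length n_t, they can be stacked column-wise.
outputs = jnp.stack([t, voltage, temperature_1d], axis=1)
print(outputs.shape)  # (5, 3)


# The reduction is built from jnp.mean, so gradients flow through it.
def loss(scale):
    return jnp.sum(reduce_to_1d(temperature * scale))


dloss = jax.grad(loss)(1.0)
print(dloss)
```

Because the reduction is an ordinary differentiable JAX operation, it could in principle be applied to the solver's 2D output before stacking or differentiating. A volume-weighted average would need the mesh spacing, which this sketch does not model.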