Hi, I am using PyBaMM for parameter estimation. In the code below, `model.full_routine` calls the following `equation` function thousands of times:
```python
import pybamm

# t, I, param1, batmodel, evaluate_pybamm and the estimation object `model`
# are defined elsewhere in my script.

# Built once, outside the fitting loop
current_interpolant = pybamm.Interpolant(t, I, pybamm.t)
param1["Current function [A]"] = current_interpolant
solver = pybamm.CasadiSolver(mode="fast")
model.solution = None


def equation(betas_list, mtx):
    # Called once per draw by model.full_routine with new betas
    def j0_pos(c_e, c_s_surf, c_s_max, T):
        # This evaluation cannot currently be used with JAX until PyBaMM
        # interpolation can be used in the JAX solver
        res = evaluate_pybamm(betas_list[0], mtx, [c_s_surf / c_s_max])
        return res

    def j0_neg(c_e, c_s_surf, c_s_max, T):
        res = evaluate_pybamm(betas_list[1], mtx, [c_s_surf / c_s_max])
        return res

    param1["Positive electrode exchange-current density [A.m-2]"] = j0_pos
    param1["Negative electrode exchange-current density [A.m-2]"] = j0_neg
    # A new Simulation is created, so the model is rebuilt on every call
    sim = pybamm.Simulation(batmodel, parameter_values=param1, solver=solver)
    model.solution = sim.solve(t, initial_soc=0.5)
    Vpb = model.solution["Voltage [V]"].entries
    return Vpb


model.set_equation(equation)
samples, matrix, BIC = model.full_routine(draws=1000, tolerance=0)
```
I would like to speed up the simulation runs, so I wanted to try two things:
- Avoid reinitializing the simulation on every call
- Find out whether the PyBaMM interpolation functions can be used with the JaxSolver, which might help with this.
Any advice on either of those, or on speeding up the runs in general, would be greatly appreciated!