More efficient way to update simulation parameters between runs and make individual runs faster

Hi, I am using PyBaMM for parameter estimation. In the code below, `model.full_routine` calls the following `equation` function thousands of times.

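```python
import pybamm

# Assumed to be defined earlier in the script: t, I (measured time/current data),
# param1 (a pybamm.ParameterValues), batmodel (a PyBaMM battery model),
# model (the parameter-estimation object), mtx and evaluate_pybamm.
current_interpolant = pybamm.Interpolant(t, I, pybamm.t)
param1["Current function [A]"] = current_interpolant

solver = pybamm.CasadiSolver(mode="fast")
model.solution = None


def equation(betas_list, mtx):
    # Exchange-current densities as functions of surface stoichiometry,
    # parameterized by the current sample of coefficients (betas_list)
    def j0_pos(c_e, c_s_surf, c_s_max, T):
        # This evaluation cannot currently be used in JAX until PyBaMM interpolation can be used in the JAX solver
        res = evaluate_pybamm(betas_list[0], mtx, [c_s_surf / c_s_max])
        return res

    def j0_neg(c_e, c_s_surf, c_s_max, T):
        res = evaluate_pybamm(betas_list[1], mtx, [c_s_surf / c_s_max])
        return res

    param1["Positive electrode exchange-current density [A.m-2]"] = j0_pos
    param1["Negative electrode exchange-current density [A.m-2]"] = j0_neg

    # A new Simulation is created on every call, so the model is
    # re-processed and re-discretised each time before solving
    sim = pybamm.Simulation(batmodel, parameter_values=param1, solver=solver)
    model.solution = sim.solve(t, initial_soc=0.5)

    Vpb = model.solution["Voltage [V]"].entries

    return Vpb


model.set_equation(equation)

samples, matrix, BIC = model.full_routine(draws=1000, tolerance=0)
```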

I would like to speed up the simulation runs, so I wanted to try two things:

  1. Avoid rebuilding the simulation on every call.
  2. Ask whether PyBaMM interpolants can be used with the `JaxSolver` to help with this.

Any advice on either of these, or on speeding up the runs in general, would be greatly appreciated!

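In general, using `pybamm.InputParameter` is the way to avoid rebuilding the simulation each time: the model is processed and discretised once, and the parameter values are passed to each solve through the `inputs` argument. See "Sensitivities and data fitting using PyBaMM" in the PyBaMM v25.1.1 Manual. However, arrays and interpolants are not currently supported as input parameters.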
