diff --git a/setup.py b/setup.py
index 7bf9b347f6ea02f9c79d20c0e0393ea196bcd0fd..e7b4b5380c955d22e50707cc10453baaac103f98 100755
--- a/setup.py
+++ b/setup.py
@@ -40,8 +40,9 @@ setup(
         'sympy >= 1.9.0',
         'scipy >= 1.8.1',
         'swiglpk >= 5.0.5',
-        'sbmlxdf >= 0.2.7'],
-    python_requires=">=3.7",
+        'sbmlxdf >= 0.2.7',
+        'jax >= 0.3.13'],
+    python_requires=">=3.9",
     keywords=['modeling', 'standardization', 'SBML'],
     include_package_data=True,
     **setup_kwargs
diff --git a/xbanalysis/problems/gba_problem.py b/xbanalysis/problems/gba_problem.py
index dee8097ba318e7c0c435f57451aa4f4390fb530c..55d9e58a6bb301632edcdb6eba296f040fd5296f 100644
--- a/xbanalysis/problems/gba_problem.py
+++ b/xbanalysis/problems/gba_problem.py
@@ -46,12 +46,13 @@ class GbaProblem:
         self.scale_density = 1e-4

         # metabolites and enzymes in the model
-        metab_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000247')]
-        enz_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000246')]
+        metab_sids = [s.id for s in xba_model.species.values()
+                      if s.sboterm.is_in_branch('SBO:0000247')]

         # optimization variables
         self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False]
-        self.var_enz_sids = enz_sids
+        self.var_enz_sids = [s.id for s in xba_model.species.values()
+                             if s.sboterm.is_in_branch('SBO:0000246')]
         self.var_sids = self.var_metab_sids + self.var_enz_sids
         self.n_vars = len(self.var_sids)
         self.var_sid2idx = {sid: idx for idx, sid in enumerate(self.var_sids)}
diff --git a/xbanalysis/problems/gba_sflux_problem.py b/xbanalysis/problems/gba_sflux_problem.py
index a8f985728fb1a09086d05b41241aaa5b72405eda..26a8b09957f013f31ca55260f6e94db94b5a88f4 100644
--- a/xbanalysis/problems/gba_sflux_problem.py
+++ b/xbanalysis/problems/gba_sflux_problem.py
@@ -17,10 +17,16 @@ import pandas as pd
 import re
 import sympy as sp
 import sbmlxdf
+import scipy
+import jax
+import types
+import jax.numpy as jnp

 from xbanalysis.model.xba_unit_def import XbaUnit
 from xbanalysis.utils.utils import expand_kineticlaw

+jax.config.update("jax_enable_x64", True)
+

 class GbaSfluxProblem:

@@ -37,7 +43,8 @@ class GbaSfluxProblem:
         metab_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000247')]
         self.const_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is True]
         self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False]
-        sid_ribo = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000250')][0]
+        sid_ribo = [s.id for s in xba_model.species.values()
+                    if s.sboterm.is_in_branch('SBO:0000250')][0]
         self.var_sids = self.var_metab_sids + [sid_ribo]
         self.sids = self.const_metab_sids + self.var_sids
         self.sid2idx = {msid: idx for idx, msid in enumerate(self.sids)}
@@ -47,37 +54,21 @@ class GbaSfluxProblem:
         rid_ps_ribo = xba_model.species[sid_ribo].ps_rid
         self.rids = self.mr_rids + [rid_ps_ribo]
         self.rid2idx = {mrid: idx for idx, mrid in enumerate(self.rids)}
+        self.enz_sids = np.array([self.xba_model.reactions[rid].enzyme for rid in self.rids])

         self.mw_sids = np.array([xba_model.species[sid].mw for sid in self.sids])
-        self.mw_enz = np.array([xba_model.species[xba_model.reactions[rid].enzyme].mw
-                                for rid in self.rids])
-        # mw_enz is used for scaling of kcats, weight for protein synthesis needs to be rescaled
-        s = xba_model.species[sid_ribo]
-        if hasattr(s, 'rba_process_costs'):
-            p_len = s.rba_process_costs['translation']
-        else:
-            r = xba_model.reactions[rid_ps_ribo]
-            p_len = min(r.reactants.values())
-        self.mw_enz[-1] *= p_len
+        self.mw_enz = np.array([xba_model.species[enz_sid].mw for enz_sid in self.enz_sids])

-        # retrieve S-matrix: internal metabos + total protein and metabolic reaction + ribosome synthesis
+        # retrieve S-matrix and calculate mass turnover in gram reactants/products per mol reactions
         sids_x_rids = xba_model.get_stoic_matrix(self.sids, self.rids)
+        mass_turnover = self.get_mass_turnover(sids_x_rids)
+        self.rt_conv_factor = 3600.0 * mass_turnover / self.mw_enz

-        # calulate mass turnover in gram reactants/products per mol reactions (g/mol)
-        sids_x_rids_react = np.where(sids_x_rids > 0, 0, sids_x_rids)
-        sids_x_rids_prod = np.where(sids_x_rids < 0, 0, sids_x_rids)
-        mass_turnover_react = -sids_x_rids_react.T.dot(self.mw_sids)
-        mass_turnover_prod = sids_x_rids_prod.T.dot(self.mw_sids)
-        self.mass_turnover = np.max((mass_turnover_react, mass_turnover_prod), axis=0)
-
-        # mass normalized stoichiometric matrix
-        self.sidx_x_rids_mn = (sids_x_rids * self.mw_sids.reshape(-1, 1)) / self.mass_turnover
-
-        self.rho = xba_model.parameters['rho'].value
-        self.ext_conc_mn = np.array([xba_model.species[sid].initial_conc * xba_model.species[sid].mw
-                                     for sid in self.const_metab_sids])
-        self.var_initial_mn = np.array([xba_model.species[sid].initial_conc * xba_model.species[sid].mw
-                                        for sid in self.var_metab_sids])
+        # mass normalized stoichiometric matrices with/without external metabolites
+        self.sids_x_rids_mn = (sids_x_rids * self.mw_sids.reshape(-1, 1)) / mass_turnover
+        self.var_sids_x_rids_mn = self.sids_x_rids_mn[-len(self.var_sids):, ]
+        self.colsum_var_sids_x_rids_mn = np.sum(self.var_sids_x_rids_mn, axis=0)
+        self.dof = self.var_sids_x_rids_mn.shape[1] - np.linalg.matrix_rank(self.var_sids_x_rids_mn)

         # identify unit ids for kinetic parameters
         units_per_s = [XbaUnit({'kind': 'second', 'exp': -1.0, 'scale': 0, 'mult': 1.0})]
@@ -90,6 +81,125 @@ class GbaSfluxProblem:
         self.kms_global = {pid: p.value for pid, p in xba_model.parameters.items()
                            if p.units == self.km_udid}

+        # model parameters
+        self.model_params = {pid: p.value for pid, p in xba_model.parameters.items()}
+        self.model_params['ext_conc_mn'] = np.array([xba_model.species[sid].initial_conc *
+                                                     xba_model.species[sid].mw
+                                                     for sid in self.const_metab_sids])
+        self.model_params['var_initial_mn'] = np.array([xba_model.species[sid].initial_conc *
+                                                        xba_model.species[sid].mw
+                                                        for sid in self.var_metab_sids])
+
+        # create functions used for optimization
+        self.jrts = self.get_reaction_times(self.rids)
+        self.obj_mu = jax.jit(self.jmu)
+        self.obj_mu_grad = jax.jit(jax.grad(self.jmu))
+        self.obj_mu_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jmu)))
+        self.constr_density = jax.jit(self.jg)
+        self.constr_density_grad = jax.jit(jax.grad(self.jg))
+        self.constr_density_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jg)))
+        self.constrs_conc = jax.jit(self.jconcentrations)
+        self.constrs_conc_jac = jax.jit(jax.jacfwd(self.jconcentrations))
+        self.constrs_conc_hesss = jax.jit(jax.jacrev(jax.jacfwd(self.jconcentrations)))
+
+        self.n_constraints = {'constr_density': 1,
+                              'constrs_conc': len(self.var_sids) + len(self.enz_sids),
+                              }
+
+    def recompile(self):
+        # recompilation of the functions is necessary after parameters have been changed
+        self.jrts = self.get_reaction_times(self.rids)
+        self.obj_mu = jax.jit(self.jmu)
+        self.obj_mu_grad = jax.jit(jax.grad(self.jmu))
+        self.obj_mu_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jmu)))
+        self.constr_density = jax.jit(self.jg)
+        self.constr_density_grad = jax.jit(jax.grad(self.jg))
+        self.constr_density_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jg)))
+        self.constrs_conc = jax.jit(self.jconcentrations)
+        self.constrs_conc_jac = jax.jit(jax.jacfwd(self.jconcentrations))
+        self.constrs_conc_hesss = jax.jit(jax.jacrev(jax.jacfwd(self.jconcentrations)))
+
+    def get_w0_initial(self):
+        p_initial_mn = self.model_params['rho'] - sum(self.model_params['var_initial_mn'])
+        initial_mn = np.hstack((self.model_params['var_initial_mn'], p_initial_mn))
+        w0 = np.ravel(scipy.linalg.pinv(self.var_sids_x_rids_mn).dot(initial_mn) / self.model_params['rho'])
+        return w0
+
+    def get_mass_turnover(self, sids_x_rids):
+        # calculate the turnover in gram reactants/products per mol reactions
+        # from full stoichiometric matrix (including external metabolites)
+        sids_x_rids_react = np.where(sids_x_rids > 0, 0, sids_x_rids)
+        sids_x_rids_prod = np.where(sids_x_rids < 0, 0, sids_x_rids)
+        mass_turnover_react = -sids_x_rids_react.T.dot(self.mw_sids)
+        mass_turnover_prod = sids_x_rids_prod.T.dot(self.mw_sids)
+        return np.max((mass_turnover_react, mass_turnover_prod), axis=0)
+
+    def jtau(self, w, model_params):
+        ci_mass = self.jci(w, model_params)
+        ca_mass = jnp.hstack((model_params['ext_conc_mn'], ci_mass))
+        ca_molar = jnp.divide(ca_mass, self.mw_sids)
+        return jnp.divide(self.jrts(ca_molar, model_params), self.rt_conv_factor)
+
+    def jmu(self, w, model_params):
+        return jnp.divide(self.var_sids_x_rids_mn[-1, -1] * w[-1], self.jtau(w, model_params).dot(w))
+
+    # Density constraint
+    # sN[:, b].dot(w[b]) -1 , with b: indices of boundary reactions
+    def jg(self, w, model_params):
+        return jnp.dot(self.colsum_var_sids_x_rids_mn, w) - 1.0
+
+    # protein concentrations
+    def jp(self, w, model_params):
+        return self.jmu(w, model_params) * model_params['rho'] * self.jtau(w, model_params) * w
+
+    # internal concentrations "i" of substrates "s" and total protein "p"
+    def jci(self, w, model_params):
+        return model_params['rho'] * jnp.dot(self.var_sids_x_rids_mn, w)
+
+    # constraint of non-negative concentrations hin(x) ≥ 0 for nlopt, !! for scipy: ≤ 0
+    def jconcentrations(self, w, model_params):
+        return jnp.hstack((self.jci(w, model_params), self.jp(w, model_params)))
+
+    def _to_vectors(self, rs_str):
+        """replace species ids by vector elements, compartment and enzyme ids by 1.0
+
+        Variables converted to x[]
+
+        :param rs_str: reaction strings (extracted from the xba model).
+        :type rs_str: list of strings
+        :returns: rs_v_str
+        :rtype: list of reaction strings with parameters replaced by vector elements
+        """
+        # replace compartment variables with 1.0
+        # replace enzyme variables with 1.0
+        # replace species ids by vector components
+        # make model parameters accessible
+        rs_v_str = []
+        for r_v_str in rs_str:
+            for compartment_id in self.xba_model.compartments:
+                r_v_str = re.sub(r'\b' + compartment_id + r'\b', str(1.0), r_v_str)
+            for enz_sid in self.enz_sids:
+                r_v_str = re.sub(r'\b' + enz_sid + r'\b', str(1.0), r_v_str)
+            for sid, idx in self.sid2idx.items():
+                r_v_str = re.sub(r'\b' + sid + r'\b', f'x[{idx}]', r_v_str)
+            for param_key in self.model_params:
+                r_v_str = re.sub(r'\b' + param_key + r'\b', f'model_params["{param_key}"]', r_v_str)
+            rs_v_str.append(r_v_str)
+        return rs_v_str
+
+    def get_reaction_times(self, rids):
+        # molar reaction turnover times
+        # bind the numpy namespace prefix 'np' used in the generated code to jax.numpy ('jnp')
+        # model parameters are passed to the reaction times function as a function parameter
+        turnover_times_str = [f'1.0/({self.xba_model.reactions[rid].expanded_kl})'
+                              for rid in rids]
+        vs_str = self._to_vectors(turnover_times_str)
+        func_code = compile('def turnover_times(x, model_params): '
+                            'return np.array([' + ', '.join(vs_str) + '])',
+                            '<string>', 'exec')
+        jrts = types.FunctionType(func_code.co_consts[0], {'np': jnp})
+        return jrts
+
     def get_unit_def_id(self, query_units):
         for udid, ud in self.xba_model.unit_defs.items():
             if ud.is_equivalent(query_units):
@@ -128,7 +238,6 @@ class GbaSfluxProblem:
         :return:
         :rtype:
         """
-
         r = self.xba_model.reactions[rid]
         kms_local = {pid: val for pid, [val, units] in r.local_params.items()
                      if units == self.km_udid}
@@ -166,13 +275,23 @@ class GbaSfluxProblem:
         :return: gba model mass normalized as per Hugo Dourado
         :rtype: dict of pandas DataFrames
         """
-        protein_mn = self.rho - sum(self.var_initial_mn)
+        protein_mn = self.model_params['rho'] - sum(self.model_params['var_initial_mn'])

         kcats = np.array([self.get_kcats(rid) for rid in self.rids]).T

-        skcats = kcats * 3600.0 * self.mass_turnover / self.mw_enz
-        km = np.zeros_like(self.sidx_x_rids_mn)
-        ki = np.zeros_like(self.sidx_x_rids_mn)
+        # in the SBML model, the kcat for protein synthesis is scaled by protein length
+        s = self.xba_model.species[self.enz_sids[-1]]
+        if hasattr(s, 'rba_process_costs'):
+            p_len = s.rba_process_costs['translation']
+        else:
+            r = self.xba_model.reactions[self.rids[-1]]
+            p_len = min(r.reactants.values())
+
+        skcats = kcats * self.rt_conv_factor
+        skcats[:, -1] = skcats[:, -1] / p_len
+
+        km = np.zeros_like(self.sids_x_rids_mn)
+        ki = np.zeros_like(self.sids_x_rids_mn)
         for rid in self.rids:
             kms, kis = self.get_kms(rid)
             for sid, val in kms.items():
@@ -182,19 +301,19 @@ class GbaSfluxProblem:
         # scale Michaelis constants with molecluar weights
         skm = km * self.mw_sids.reshape(-1, 1)
         ski = ki * self.mw_sids.reshape(-1, 1)
-        ska = np.zeros_like(self.sidx_x_rids_mn)
+        ska = np.zeros_like(self.sids_x_rids_mn)

         mn_model = {
-            'N': pd.DataFrame(self.sidx_x_rids_mn, index=self.sids, columns=self.rids),
+            'N': pd.DataFrame(self.sids_x_rids_mn, index=self.sids, columns=self.rids),
             'KM': pd.DataFrame(skm, index=self.sids, columns=self.rids),
             'KI': pd.DataFrame(ski, index=self.sids, columns=self.rids),
             'KA': pd.DataFrame(ska, index=self.sids, columns=self.rids),
             'kcat': pd.DataFrame(skcats, index=['kcat_f', 'kcat_b'], columns=self.rids),
-            'conditions': pd.DataFrame(np.hstack((self.rho, self.ext_conc_mn)),
+            'conditions': pd.DataFrame(np.hstack((self.model_params['rho'], self.model_params['ext_conc_mn'])),
                                        index=['rho'] + self.const_metab_sids, columns=[1]),
             'lower_c': pd.DataFrame(np.zeros(len(self.var_sids)),
                                     index=self.var_sids, columns=['lower']),
-            'initial': pd.DataFrame(np.hstack((self.var_initial_mn, protein_mn)),
+            'initial': pd.DataFrame(np.hstack((self.model_params['var_initial_mn'], protein_mn)),
                                     index=self.var_sids, columns=['initial']),
         }
         return mn_model
diff --git a/xbanalysis/solvers/__init__.py b/xbanalysis/solvers/__init__.py
index 5907eaa15781be96589add06e40e15e15185ef89..4e4ee2ebb3ce25a72131cc942ddd2729f5b1ae21 100644
--- a/xbanalysis/solvers/__init__.py
+++ b/xbanalysis/solvers/__init__.py
@@ -2,5 +2,6 @@

 from .glpk_linear_problem import GlpkLinearProblem
 from .gba_ipopt_problem import GbaIpoptProblem
+from .gba_sflux_ipopt_problem import GbaSfluxIpoptProblem

-__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem']
+__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem', 'GbaSfluxIpoptProblem']
diff --git a/xbanalysis/solvers/gba_sflux_ipopt_problem.py b/xbanalysis/solvers/gba_sflux_ipopt_problem.py
new file mode 100644
index 0000000000000000000000000000000000000000..49718a3b9277bf112eabfb227ce2ba3f74850de3
--- /dev/null
+++ b/xbanalysis/solvers/gba_sflux_ipopt_problem.py
@@ -0,0 +1,75 @@
+"""Implementation of GbaSfluxIpoptProblem class.
+
+For solving the GBA problem using scaled fluxes as per Hugo Dourado
+
+Peter Schubert, HHU Duesseldorf, September 2022
+"""
+
+import numpy as np
+
+
+# based on cyipopt examples
+class GbaSfluxIpoptProblem:
+
+    def __init__(self, gba_sflux_problem, model_params=None):
+        """Initialize GbaSfluxIpoptProblem.
+
+        :param gba_sflux_problem: GbaSfluxProblem from where to extract relevant parameters
+        :type gba_sflux_problem: GbaSfluxProblem
+        :param model_params: model parameters (optional, default None; uses params from GbaSfluxProblem)
+        :type model_params: None or dict
+        """
+        self.sflux_problem = gba_sflux_problem
+        if model_params is None:
+            self.model_params = gba_sflux_problem.model_params
+        else:
+            self.model_params = model_params
+        self._n_vars = len(gba_sflux_problem.rids)
+        self.report_freq = 0
+
+        self.nfev = 0
+        self.njev = 0
+        self.nit = 0
+        self.constraint_dims = sum(gba_sflux_problem.n_constraints.values())
+
+    def objective(self, x):
+        self.nfev += 1
+        return -np.float64(self.sflux_problem.obj_mu(x, self.model_params))
+
+    def gradient(self, x):
+        self.njev += 1
+        return -np.array(self.sflux_problem.obj_mu_grad(x, self.model_params))
+
+    def constraints(self, x):
+        # TODO: improve code
+        eqs = np.hstack((self.sflux_problem.constr_density(x, self.model_params),
+                         self.sflux_problem.constrs_conc(x, self.model_params)))
+        return eqs
+
+    def jacobian(self, x):
+        # TODO: improve code
+        jacs = (np.vstack((self.sflux_problem.constr_density_grad(x, self.model_params),
+                           self.sflux_problem.constrs_conc_jac(x, self.model_params)))
+                ).flatten()
+        return jacs
+
+    # def hessianstructure(self):
+    #     return np.tril_indices(self._n_vars)
+
+    def hessian(self, x, lagrange, obj_factor):
+        hess = obj_factor * -self.sflux_problem.obj_mu_hess(x, self.model_params)
+        hesss = np.vstack(([self.sflux_problem.constr_density_hess(x, self.model_params)],
+                           self.sflux_problem.constrs_conc_hesss(x, self.model_params)))
+        hess += (lagrange.reshape((-1, 1, 1)) * hesss).sum(axis=0)
+        return hess[np.tril_indices(self._n_vars)]
+
+    def set_report_freq(self, mod_val=0):
+        self.report_freq = mod_val
+
+    # noinspection PyUnusedLocal
+    def intermediate(self, alg_mod, iter_count, obj_value, inf_pr, inf_du, mu,
+                     d_norm, regularization_size, alpha_du, alpha_pr, ls_trials):
+        self.nit = iter_count
+        if self.report_freq > 0:
+            if iter_count % self.report_freq == 0:
+                print(f'[{iter_count:5d}] growth rate: {-obj_value:.4f} 1/h')
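
Review note on the JAX plumbing above: `GbaSfluxProblem` builds its objective, gradient and Hessian callables by composing `jax.jit` with `jax.grad`, `jax.jacfwd` and `jax.jacrev(jax.jacfwd(...))`. A minimal, self-contained sketch of that pattern, with a toy objective standing in for `jmu` (the toy function and the parameter values are invented for illustration):

```python
import jax
import jax.numpy as jnp

# same double-precision switch as in gba_sflux_problem.py
jax.config.update("jax_enable_x64", True)


def toy_mu(w, model_params):
    # toy scalar objective standing in for GbaSfluxProblem.jmu(w, model_params)
    return model_params['rho'] * w[-1] / (1.0 + jnp.dot(w, w))


obj_mu = jax.jit(toy_mu)                               # jit-compiled objective
obj_mu_grad = jax.jit(jax.grad(toy_mu))                # gradient w.r.t. w
obj_mu_hess = jax.jit(jax.jacrev(jax.jacfwd(toy_mu)))  # dense Hessian (reverse-over-forward)

w0 = jnp.full(3, 0.1)
params = {'rho': 300.0}
print(obj_mu(w0, params))       # scalar
print(obj_mu_grad(w0, params))  # shape (3,)
print(obj_mu_hess(w0, params))  # shape (3, 3)
```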
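The `get_reaction_times` helper generates Python source for the turnover-time vector, compiles it, and rebinds the name `np` inside the generated function to `jax.numpy`, so the result can be traced and differentiated by JAX. A stripped-down sketch of that `compile`/`types.FunctionType` trick; the kinetic-law string here is a made-up Michaelis-Menten example, not taken from an actual model:

```python
import types
import jax
import jax.numpy as jnp

# generated source, in the same shape _to_vectors()/get_reaction_times() would produce
src = ('def turnover_times(x, model_params): '
       'return np.array([1.0/(model_params["kcat"] * x[0]/(x[0] + model_params["km"]))])')

func_code = compile(src, '<string>', 'exec')
# the first code constant of the compiled module is the function body;
# bind its global name 'np' to jax.numpy
turnover_times = types.FunctionType(func_code.co_consts[0], {'np': jnp})

params = {'kcat': 10.0, 'km': 5.0}
print(turnover_times(jnp.array([2.0]), params))              # plain evaluation
print(jax.jacfwd(turnover_times)(jnp.array([2.0]), params))  # and it is JAX-differentiable
```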
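Finally, a rough end-to-end sketch of how the new solver class might be driven with cyipopt. `GbaSfluxProblem`, `GbaSfluxIpoptProblem`, `get_w0_initial()` and `n_constraints` come from the diff above; the `XbaModel` loading call, the import paths for `XbaModel`/`GbaSfluxProblem`, the variable bounds and the ipopt option are assumptions for illustration only, while `cyipopt.Problem`, `add_option` and `solve` are the standard cyipopt interface:

```python
import numpy as np
import cyipopt

from xbanalysis.model import XbaModel              # assumed import path
from xbanalysis.problems import GbaSfluxProblem    # assumed to be re-exported by the package
from xbanalysis.solvers import GbaSfluxIpoptProblem

xba_model = XbaModel('gba_model.xml')              # hypothetical SBML model file
sflux_problem = GbaSfluxProblem(xba_model)
ipopt_problem = GbaSfluxIpoptProblem(sflux_problem)
ipopt_problem.set_report_freq(10)                  # report growth rate every 10 iterations

n_vars = len(sflux_problem.rids)
n_constrs = sum(sflux_problem.n_constraints.values())

nlp = cyipopt.Problem(
    n=n_vars, m=n_constrs, problem_obj=ipopt_problem,
    lb=np.full(n_vars, -10.0), ub=np.full(n_vars, 10.0),  # assumed box bounds on scaled fluxes
    cl=np.zeros(n_constrs),                               # density constraint as equality,
    cu=np.hstack((0.0, np.full(n_constrs - 1, np.inf))),  # concentration constraints >= 0
)
nlp.add_option('mu_strategy', 'adaptive')

x_opt, info = nlp.solve(sflux_problem.get_w0_initial())
print(f"optimal growth rate: {-info['obj_val']:.4f} 1/h")
```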