Commit 996c0218 authored by Peter Schubert

gba_sflux_problem solved via ipopt using jax autograd and jit

parent abd06b7d
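The commit message summarizes the approach: the objective and constraints of the scaled-flux GBA problem are written as pure JAX functions, exact gradients and Hessians are obtained by automatic differentiation, and everything is jit-compiled before being handed to Ipopt. Below is a minimal, self-contained sketch of that recipe; the `rosenbrock` objective is purely illustrative and not part of the model code.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)   # Ipopt works in double precision

def rosenbrock(x):
    # illustrative smooth objective; the real code differentiates the GBA growth rate jmu(w, model_params)
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

# same recipe as in GbaSfluxProblem: exact derivatives via autograd, compiled with jit
obj = jax.jit(rosenbrock)
obj_grad = jax.jit(jax.grad(rosenbrock))
obj_hess = jax.jit(jax.jacrev(jax.jacfwd(rosenbrock)))

x0 = jnp.array([1.2, 1.0, 0.8])
print(obj(x0))        # scalar objective value
print(obj_grad(x0))   # exact gradient, shape (3,)
print(obj_hess(x0))   # exact dense Hessian, shape (3, 3)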
@@ -40,8 +40,9 @@ setup(
         'sympy >= 1.9.0',
         'scipy >= 1.8.1',
         'swiglpk >= 5.0.5',
-        'sbmlxdf >= 0.2.7'],
-    python_requires=">=3.7",
+        'sbmlxdf >= 0.2.7',
+        'jax >= 0.3.13'],
+    python_requires=">=3.9",
     keywords=['modeling', 'standardization', 'SBML'],
     include_package_data=True,
     **setup_kwargs
@@ -46,12 +46,13 @@ class GbaProblem:
         self.scale_density = 1e-4
         # metabolites and enzymes in the model
-        metab_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000247')]
-        enz_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000246')]
+        metab_sids = [s.id for s in xba_model.species.values()
+                      if s.sboterm.is_in_branch('SBO:0000247')]
         # optimization variables
         self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False]
-        self.var_enz_sids = enz_sids
+        self.var_enz_sids = [s.id for s in xba_model.species.values()
+                             if s.sboterm.is_in_branch('SBO:0000246')]
         self.var_sids = self.var_metab_sids + self.var_enz_sids
         self.n_vars = len(self.var_sids)
         self.var_sid2idx = {sid: idx for idx, sid in enumerate(self.var_sids)}
@@ -17,10 +17,16 @@ import pandas as pd
 import re
 import sympy as sp
 import sbmlxdf
+import scipy
+import jax
+import types
+import jax.numpy as jnp

 from xbanalysis.model.xba_unit_def import XbaUnit
 from xbanalysis.utils.utils import expand_kineticlaw

+jax.config.update("jax_enable_x64", True)
+

 class GbaSfluxProblem:
@@ -37,7 +43,8 @@ class GbaSfluxProblem:
         metab_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000247')]
         self.const_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is True]
         self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False]
-        sid_ribo = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000250')][0]
+        sid_ribo = [s.id for s in xba_model.species.values()
+                    if s.sboterm.is_in_branch('SBO:0000250')][0]
         self.var_sids = self.var_metab_sids + [sid_ribo]
         self.sids = self.const_metab_sids + self.var_sids
         self.sid2idx = {msid: idx for idx, msid in enumerate(self.sids)}
@@ -47,37 +54,21 @@ class GbaSfluxProblem:
         rid_ps_ribo = xba_model.species[sid_ribo].ps_rid
         self.rids = self.mr_rids + [rid_ps_ribo]
         self.rid2idx = {mrid: idx for idx, mrid in enumerate(self.rids)}
+        self.enz_sids = np.array([self.xba_model.reactions[rid].enzyme for rid in self.rids])
         self.mw_sids = np.array([xba_model.species[sid].mw for sid in self.sids])
-        self.mw_enz = np.array([xba_model.species[xba_model.reactions[rid].enzyme].mw
-                                for rid in self.rids])
-        # mw_enz is used for scaling of kcats, weight for protein synthesis needs to be rescaled
-        s = xba_model.species[sid_ribo]
-        if hasattr(s, 'rba_process_costs'):
-            p_len = s.rba_process_costs['translation']
-        else:
-            r = xba_model.reactions[rid_ps_ribo]
-            p_len = min(r.reactants.values())
-        self.mw_enz[-1] *= p_len
+        self.mw_enz = np.array([xba_model.species[enz_sid].mw for enz_sid in self.enz_sids])

-        # retrieve S-matrix: internal metabolites + total protein and metabolic reactions + ribosome synthesis
+        # retrieve S-matrix and calculate mass turnover in gram reactants/products per mol reactions
         sids_x_rids = xba_model.get_stoic_matrix(self.sids, self.rids)
+        mass_turnover = self.get_mass_turnover(sids_x_rids)
+        self.rt_conv_factor = 3600.0 * mass_turnover / self.mw_enz

-        # calculate mass turnover in gram reactants/products per mol reactions (g/mol)
-        sids_x_rids_react = np.where(sids_x_rids > 0, 0, sids_x_rids)
-        sids_x_rids_prod = np.where(sids_x_rids < 0, 0, sids_x_rids)
-        mass_turnover_react = -sids_x_rids_react.T.dot(self.mw_sids)
-        mass_turnover_prod = sids_x_rids_prod.T.dot(self.mw_sids)
-        self.mass_turnover = np.max((mass_turnover_react, mass_turnover_prod), axis=0)
-
-        # mass normalized stoichiometric matrix
-        self.sidx_x_rids_mn = (sids_x_rids * self.mw_sids.reshape(-1, 1)) / self.mass_turnover
-
-        self.rho = xba_model.parameters['rho'].value
-        self.ext_conc_mn = np.array([xba_model.species[sid].initial_conc * xba_model.species[sid].mw
-                                     for sid in self.const_metab_sids])
-        self.var_initial_mn = np.array([xba_model.species[sid].initial_conc * xba_model.species[sid].mw
-                                        for sid in self.var_metab_sids])
+        # mass normalized stoichiometric matrices with/without external metabolites
+        self.sids_x_rids_mn = (sids_x_rids * self.mw_sids.reshape(-1, 1)) / mass_turnover
+        self.var_sids_x_rids_mn = self.sids_x_rids_mn[-len(self.var_sids):, ]
+        self.colsum_var_sids_x_rids_mn = np.sum(self.var_sids_x_rids_mn, axis=0)
+        self.dof = self.var_sids_x_rids_mn.shape[1] - np.linalg.matrix_rank(self.var_sids_x_rids_mn)

         # identify unit ids for kinetic parameters
         units_per_s = [XbaUnit({'kind': 'second', 'exp': -1.0, 'scale': 0, 'mult': 1.0})]
@@ -90,6 +81,125 @@ class GbaSfluxProblem:
         self.kms_global = {pid: p.value for pid, p in xba_model.parameters.items()
                            if p.units == self.km_udid}

+        # model parameters
+        self.model_params = {pid: p.value for pid, p in xba_model.parameters.items()}
+        self.model_params['ext_conc_mn'] = np.array([xba_model.species[sid].initial_conc *
+                                                     xba_model.species[sid].mw
+                                                     for sid in self.const_metab_sids])
+        self.model_params['var_initial_mn'] = np.array([xba_model.species[sid].initial_conc *
+                                                        xba_model.species[sid].mw
+                                                        for sid in self.var_metab_sids])
+
+        # create functions used for optimization
+        self.jrts = self.get_reaction_times(self.rids)
+        self.obj_mu = jax.jit(self.jmu)
+        self.obj_mu_grad = jax.jit(jax.grad(self.jmu))
+        self.obj_mu_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jmu)))
+        self.constr_density = jax.jit(self.jg)
+        self.constr_density_grad = jax.jit(jax.grad(self.jg))
+        self.constr_density_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jg)))
+        self.constrs_conc = jax.jit(self.jconcentrations)
+        self.constrs_conc_jac = jax.jit(jax.jacfwd(self.jconcentrations))
+        self.constrs_conc_hesss = jax.jit(jax.jacrev(jax.jacfwd(self.jconcentrations)))
+        self.n_constraints = {'constr_density': 1,
+                              'constrs_conc': len(self.var_sids) + len(self.enz_sids),
+                              }
+
+    def recompile(self):
+        # recompilation of the jitted functions is necessary after parameters have been changed
+        self.jrts = self.get_reaction_times(self.rids)
+        self.obj_mu = jax.jit(self.jmu)
+        self.obj_mu_grad = jax.jit(jax.grad(self.jmu))
+        self.obj_mu_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jmu)))
+        self.constr_density = jax.jit(self.jg)
+        self.constr_density_grad = jax.jit(jax.grad(self.jg))
+        self.constr_density_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jg)))
+        self.constrs_conc = jax.jit(self.jconcentrations)
+        self.constrs_conc_jac = jax.jit(jax.jacfwd(self.jconcentrations))
+        self.constrs_conc_hesss = jax.jit(jax.jacrev(jax.jacfwd(self.jconcentrations)))
+
+    def get_w0_initial(self):
+        p_initial_mn = self.model_params['rho'] - sum(self.model_params['var_initial_mn'])
+        initial_mn = np.hstack((self.model_params['var_initial_mn'], p_initial_mn))
+        w0 = np.ravel(scipy.linalg.pinv(self.var_sids_x_rids_mn).dot(initial_mn) / self.model_params['rho'])
+        return w0
+
+    def get_mass_turnover(self, sids_x_rids):
+        # calculate the turnover in gram reactants/products per mol reactions
+        # from full stoichiometric matrix (including external metabolites)
+        sids_x_rids_react = np.where(sids_x_rids > 0, 0, sids_x_rids)
+        sids_x_rids_prod = np.where(sids_x_rids < 0, 0, sids_x_rids)
+        mass_turnover_react = -sids_x_rids_react.T.dot(self.mw_sids)
+        mass_turnover_prod = sids_x_rids_prod.T.dot(self.mw_sids)
+        return np.max((mass_turnover_react, mass_turnover_prod), axis=0)
+
+    def jtau(self, w, model_params):
+        ci_mass = self.jci(w, model_params)
+        ca_mass = jnp.hstack((model_params['ext_conc_mn'], ci_mass))
+        ca_molar = jnp.divide(ca_mass, self.mw_sids)
+        return jnp.divide(self.jrts(ca_molar, model_params), self.rt_conv_factor)
+
+    def jmu(self, w, model_params):
+        return jnp.divide(self.var_sids_x_rids_mn[-1, -1] * w[-1], self.jtau(w, model_params).dot(w))
+
+    # density constraint
+    # sN[:, b].dot(w[b]) - 1, with b: indices of boundary reactions
+    def jg(self, w, model_params):
+        return jnp.dot(self.colsum_var_sids_x_rids_mn, w) - 1.0
+
+    # protein concentrations
+    def jp(self, w, model_params):
+        return self.jmu(w, model_params) * model_params['rho'] * self.jtau(w, model_params) * w
+
+    # internal concentrations "i" of substrates "s" and total protein "p"
+    def jci(self, w, model_params):
+        return model_params['rho'] * jnp.dot(self.var_sids_x_rids_mn, w)
+
+    # constraint of non-negative concentrations hin(x) ≥ 0 for nlopt, !! for scipy: ≤ 0
+    def jconcentrations(self, w, model_params):
+        return jnp.hstack((self.jci(w, model_params), self.jp(w, model_params)))
+
+    def _to_vectors(self, rs_str):
+        """Replace species ids by vector elements, compartment and enzyme ids by 1.0.
+
+        Species ids are converted to x[idx], model parameters to model_params["..."].
+
+        :param rs_str: reaction strings (extracted from the xba model)
+        :type rs_str: list of strings
+        :returns: rs_v_str
+        :rtype: list of reaction strings with parameters replaced by vector elements
+        """
+        # replace compartment variables with 1.0
+        # replace enzyme variables with 1.0
+        # replace species ids by vector components
+        # make model parameters accessible
+        rs_v_str = []
+        for r_v_str in rs_str:
+            for compartment_id in self.xba_model.compartments:
+                r_v_str = re.sub(r'\b' + compartment_id + r'\b', str(1.0), r_v_str)
+            for enz_sid in self.enz_sids:
+                r_v_str = re.sub(r'\b' + enz_sid + r'\b', str(1.0), r_v_str)
+            for sid, idx in self.sid2idx.items():
+                r_v_str = re.sub(r'\b' + sid + r'\b', f'x[{idx}]', r_v_str)
+            for param_key in self.model_params:
+                r_v_str = re.sub(r'\b' + param_key + r'\b', f'model_params["{param_key}"]', r_v_str)
+            rs_v_str.append(r_v_str)
+        return rs_v_str
+
+    def get_reaction_times(self, rids):
+        # molar reaction turnover times
+        # the numpy namespace prefix 'np' used in the generated code is bound to 'jnp' (jax.numpy)
+        # model parameters are passed to the generated reaction times function as a function parameter
+        turnover_times_str = [f'1.0/({self.xba_model.reactions[rid].expanded_kl})'
+                              for rid in rids]
+        vs_str = self._to_vectors(turnover_times_str)
+        func_code = compile('def turnover_times(x, model_params): '
+                            'return np.array([' + ', '.join(vs_str) + '])',
+                            '<string>', 'exec')
+        jrts = types.FunctionType(func_code.co_consts[0], {'np': jnp})
+        return jrts
+
     def get_unit_def_id(self, query_units):
         for udid, ud in self.xba_model.unit_defs.items():
             if ud.is_equivalent(query_units):
@@ -128,7 +238,6 @@ class GbaSfluxProblem:
         :return:
         :rtype:
         """
-
         r = self.xba_model.reactions[rid]
         kms_local = {pid: val for pid, [val, units] in r.local_params.items() if units == self.km_udid}
@@ -166,13 +275,23 @@ class GbaSfluxProblem:
         :return: gba model mass normalized as per Hugo Dourado
         :rtype: dict of pandas DataFrames
         """
-        protein_mn = self.rho - sum(self.var_initial_mn)
+        protein_mn = self.model_params['rho'] - sum(self.model_params['var_initial_mn'])
         kcats = np.array([self.get_kcats(rid) for rid in self.rids]).T
-        skcats = kcats * 3600.0 * self.mass_turnover / self.mw_enz
-        km = np.zeros_like(self.sidx_x_rids_mn)
-        ki = np.zeros_like(self.sidx_x_rids_mn)
+
+        # in the SBML model, the kcat for protein synthesis is scaled by protein length
+        s = self.xba_model.species[self.enz_sids[-1]]
+        if hasattr(s, 'rba_process_costs'):
+            p_len = s.rba_process_costs['translation']
+        else:
+            r = self.xba_model.reactions[self.rids[-1]]
+            p_len = min(r.reactants.values())
+        skcats = kcats * self.rt_conv_factor
+        skcats[:, -1] = skcats[:, -1] / p_len
+
+        km = np.zeros_like(self.sids_x_rids_mn)
+        ki = np.zeros_like(self.sids_x_rids_mn)
         for rid in self.rids:
             kms, kis = self.get_kms(rid)
             for sid, val in kms.items():
@@ -182,19 +301,19 @@ class GbaSfluxProblem:
         # scale Michaelis constants with molecular weights
         skm = km * self.mw_sids.reshape(-1, 1)
         ski = ki * self.mw_sids.reshape(-1, 1)
-        ska = np.zeros_like(self.sidx_x_rids_mn)
+        ska = np.zeros_like(self.sids_x_rids_mn)

         mn_model = {
-            'N': pd.DataFrame(self.sidx_x_rids_mn, index=self.sids, columns=self.rids),
+            'N': pd.DataFrame(self.sids_x_rids_mn, index=self.sids, columns=self.rids),
             'KM': pd.DataFrame(skm, index=self.sids, columns=self.rids),
             'KI': pd.DataFrame(ski, index=self.sids, columns=self.rids),
             'KA': pd.DataFrame(ska, index=self.sids, columns=self.rids),
             'kcat': pd.DataFrame(skcats, index=['kcat_f', 'kcat_b'], columns=self.rids),
-            'conditions': pd.DataFrame(np.hstack((self.rho, self.ext_conc_mn)),
+            'conditions': pd.DataFrame(np.hstack((self.model_params['rho'], self.model_params['ext_conc_mn'])),
                                        index=['rho'] + self.const_metab_sids, columns=[1]),
             'lower_c': pd.DataFrame(np.zeros(len(self.var_sids)),
                                     index=self.var_sids, columns=['lower']),
-            'initial': pd.DataFrame(np.hstack((self.var_initial_mn, protein_mn)),
+            'initial': pd.DataFrame(np.hstack((self.model_params['var_initial_mn'], protein_mn)),
                                     index=self.var_sids, columns=['initial']),
         }
         return mn_model
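In `get_reaction_times` above, the expanded kinetic-law strings from the SBML model are rewritten by `_to_vectors` into expressions over `x[idx]` and `model_params[...]`, compiled with `compile()`, and turned into a real function whose `np` name is bound to `jax.numpy`, so the generated turnover-time function stays traceable by `jax.grad` and `jax.jit`. A stripped-down sketch of that mechanism, using made-up rate expressions and parameter names:

import types
import jax
import jax.numpy as jnp

# hypothetical kinetic-law strings, already rewritten in the style of _to_vectors:
# species ids -> x[idx], model parameters -> model_params["..."]
vs_str = ['1.0/(model_params["kcat_r1"] * x[0] / (model_params["km_r1"] + x[0]))',
          '1.0/(model_params["kcat_r2"] * x[1] / (model_params["km_r2"] + x[1]))']

func_code = compile('def turnover_times(x, model_params): '
                    'return np.array([' + ', '.join(vs_str) + '])',
                    '<string>', 'exec')
# bind the generated code to jax.numpy so jit/grad can trace it
jrts = types.FunctionType(func_code.co_consts[0], {'np': jnp})

x = jnp.array([2.0, 0.5])
model_params = {'kcat_r1': 100.0, 'km_r1': 1.0, 'kcat_r2': 50.0, 'km_r2': 0.2}
print(jax.jit(jrts)(x, model_params))   # molar turnover times, one entry per reaction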
@@ -2,5 +2,6 @@
 from .glpk_linear_problem import GlpkLinearProblem
 from .gba_ipopt_problem import GbaIpoptProblem
+from .gba_sflux_ipopt_problem import GbaSfluxIpoptProblem

-__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem']
+__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem', 'GbaSfluxIpoptProblem']
"""Implementation of GbaSfluxIpoptProblem class.
For solving GBA problem using scaled fluxes as per Hugo Dourado
Peter Schubert, HHU Duesseldorf, September 2022
"""
import numpy as np
# based on cyipopt examples
class GbaSfluxIpoptProblem:
def __init__(self, gba_sflux_problem, model_params=None):
"""Initialize sFluxIpoptProblem.
:param gba_sflux_problem: GbaSfluxProblem from where to extract relevant parameters
:type gba_sflux_problem: GbaSfluxProblem
:param model_params: model parameters (optional, default None, used params from GbaSfluxProblem
:type model_params: None or dict
"""
self.sflux_problem = gba_sflux_problem
if model_params is None:
self.model_params = gba_sflux_problem.model_params
else:
self.model_params = model_params
self._n_vars = len(gba_sflux_problem.rids)
self.report_freq = 0
self.nfev = 0
self.njev = 0
self.nit = 0
self.constraint_dims = sum(gba_sflux_problem.n_constraints.values())
def objective(self, x):
self.nfev += 1
return -np.float64(self.sflux_problem.obj_mu(x, self.model_params))
def gradient(self, x):
self.njev += 1
return -np.array(self.sflux_problem.obj_mu_grad(x, self.model_params))
def constraints(self, x):
# TODO: improve code
eqs = np.hstack((self.sflux_problem.constr_density(x, self.model_params),
self.sflux_problem.constrs_conc(x, self.model_params)))
return eqs
def jacobian(self, x):
# TODO: improve code
jacs = (np.vstack((self.sflux_problem.constr_density_grad(x, self.model_params),
self.sflux_problem.constrs_conc_jac(x, self.model_params)))
).flatten()
return jacs
# def hessianstructure(self):
# return np.tril_indices(self._n_vars)
def hessian(self, x, lagrange, obj_factor):
hess = obj_factor * -self.sflux_problem.obj_mu_hess(x, self.model_params)
hesss = np.vstack(([self.sflux_problem.constr_density_hess(x, self.model_params)],
self.sflux_problem.constrs_conc_hesss(x, self.model_params)))
hess += (lagrange.reshape((-1, 1, 1)) * hesss).sum(axis=0)
return hess[np.tril_indices(self._n_vars)]
def set_report_freq(self, mod_val=0):
self.report_freq = mod_val
# noinspection PyUnusedLocal
def intermediate(self, alg_mod, iter_count, obj_value, inf_pr, inf_du, mu,
d_norm, regularization_size, alpha_du, alpha_pr, ls_trials):
self.nit = iter_count
if self.report_freq > 0:
if iter_count % self.report_freq == 0:
print(f'[{iter_count:5d}] growth rate: {-obj_value:.4f} 1/h')
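GbaSfluxIpoptProblem exposes the `objective`, `gradient`, `constraints`, `jacobian`, `hessian` and `intermediate` callbacks that cyipopt expects from a problem object. A hedged sketch of how the pieces could be wired together; the variable bounds, the split of the constraint limits and the already-loaded `xba_model` are illustrative assumptions, not taken from this commit:

import numpy as np
import cyipopt  # assumption: cyipopt is installed; the class above follows its problem-object interface

# assumption: xba_model was created elsewhere, e.g. loaded from an SBML file
sflux_problem = GbaSfluxProblem(xba_model)
ipopt_problem = GbaSfluxIpoptProblem(sflux_problem)
ipopt_problem.set_report_freq(10)

w0 = sflux_problem.get_w0_initial()          # initial guess for the scaled fluxes
m = ipopt_problem.constraint_dims
# assumed split: density constraint as equality (= 0), concentrations kept non-negative
cl = np.hstack(([0.0], np.zeros(m - 1)))
cu = np.hstack(([0.0], np.full(m - 1, np.inf)))

nlp = cyipopt.Problem(n=len(w0), m=m, problem_obj=ipopt_problem,
                      lb=np.full(len(w0), -10.0), ub=np.full(len(w0), 10.0),  # illustrative bounds
                      cl=cl, cu=cu)
w_opt, info = nlp.solve(w0)
print('maximum growth rate:', -info['obj_val'], '1/h')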