Commit 996c0218 authored by Peter Schubert

gba_sflux_problem solved via ipopt using jax autograd and jit

parent abd06b7d
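The commit follows the usual JAX-for-Ipopt pattern: objective and constraint functions are written with jax.numpy and then wrapped with jax.jit, jax.grad and jax.jacrev(jax.jacfwd(...)) to obtain the jit-compiled gradients, Jacobians and Hessians that the Ipopt interface needs. A minimal sketch of that pattern (illustrative only; f and w0 are placeholder names, not code from this commit):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)      # Ipopt works in float64

def f(w):
    # toy smooth objective standing in for the growth-rate function jmu below
    return jnp.sum(w ** 2) + jnp.prod(w)

obj = jax.jit(f)                               # compiled objective
obj_grad = jax.jit(jax.grad(f))                # compiled gradient
obj_hess = jax.jit(jax.jacrev(jax.jacfwd(f)))  # compiled dense Hessian

w0 = jnp.ones(3)
print(obj(w0), obj_grad(w0), obj_hess(w0))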
@@ -40,8 +40,9 @@ setup(
'sympy >= 1.9.0',
'scipy >= 1.8.1',
'swiglpk >= 5.0.5',
'sbmlxdf >= 0.2.7'],
python_requires=">=3.7",
'sbmlxdf >= 0.2.7',
'jax >= 0.3.13'],
python_requires=">=3.9",
keywords=['modeling', 'standardization', 'SBML'],
include_package_data=True,
**setup_kwargs
@@ -46,12 +46,13 @@ class GbaProblem:
self.scale_density = 1e-4
# metabolites and enzymes in the model
metab_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000247')]
enz_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000246')]
metab_sids = [s.id for s in xba_model.species.values()
if s.sboterm.is_in_branch('SBO:0000247')]
# optimization variables
self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False]
self.var_enz_sids = enz_sids
self.var_enz_sids = [s.id for s in xba_model.species.values()
if s.sboterm.is_in_branch('SBO:0000246')]
self.var_sids = self.var_metab_sids + self.var_enz_sids
self.n_vars = len(self.var_sids)
self.var_sid2idx = {sid: idx for idx, sid in enumerate(self.var_sids)}
@@ -17,10 +17,16 @@ import pandas as pd
import re
import sympy as sp
import sbmlxdf
import scipy
import jax
import types
import jax.numpy as jnp
from xbanalysis.model.xba_unit_def import XbaUnit
from xbanalysis.utils.utils import expand_kineticlaw
jax.config.update("jax_enable_x64", True)
class GbaSfluxProblem:
@@ -37,7 +43,8 @@ class GbaSfluxProblem:
metab_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000247')]
self.const_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is True]
self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False]
sid_ribo = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000250')][0]
sid_ribo = [s.id for s in xba_model.species.values()
if s.sboterm.is_in_branch('SBO:0000250')][0]
self.var_sids = self.var_metab_sids + [sid_ribo]
self.sids = self.const_metab_sids + self.var_sids
self.sid2idx = {msid: idx for idx, msid in enumerate(self.sids)}
@@ -47,37 +54,21 @@ class GbaSfluxProblem:
rid_ps_ribo = xba_model.species[sid_ribo].ps_rid
self.rids = self.mr_rids + [rid_ps_ribo]
self.rid2idx = {mrid: idx for idx, mrid in enumerate(self.rids)}
self.enz_sids = np.array([self.xba_model.reactions[rid].enzyme for rid in self.rids])
self.mw_sids = np.array([xba_model.species[sid].mw for sid in self.sids])
self.mw_enz = np.array([xba_model.species[xba_model.reactions[rid].enzyme].mw
for rid in self.rids])
# mw_enz is used for scaling of kcats; the weight for protein synthesis needs to be rescaled
s = xba_model.species[sid_ribo]
if hasattr(s, 'rba_process_costs'):
p_len = s.rba_process_costs['translation']
else:
r = xba_model.reactions[rid_ps_ribo]
p_len = min(r.reactants.values())
self.mw_enz[-1] *= p_len
self.mw_enz = np.array([xba_model.species[enz_sid].mw for enz_sid in self.enz_sids])
# retrieve S-matrix: internal metabolites + total protein and metabolic reactions + ribosome synthesis
# retrieve S-matrix and calculate mass turnover in gram reactants/products per mol reactions
sids_x_rids = xba_model.get_stoic_matrix(self.sids, self.rids)
mass_turnover = self.get_mass_turnover(sids_x_rids)
self.rt_conv_factor = 3600.0 * mass_turnover / self.mw_enz
# calculate mass turnover in gram reactants/products per mol reactions (g/mol)
sids_x_rids_react = np.where(sids_x_rids > 0, 0, sids_x_rids)
sids_x_rids_prod = np.where(sids_x_rids < 0, 0, sids_x_rids)
mass_turnover_react = -sids_x_rids_react.T.dot(self.mw_sids)
mass_turnover_prod = sids_x_rids_prod.T.dot(self.mw_sids)
self.mass_turnover = np.max((mass_turnover_react, mass_turnover_prod), axis=0)
# mass normalized stoichiometric matrix
self.sidx_x_rids_mn = (sids_x_rids * self.mw_sids.reshape(-1, 1)) / self.mass_turnover
self.rho = xba_model.parameters['rho'].value
self.ext_conc_mn = np.array([xba_model.species[sid].initial_conc * xba_model.species[sid].mw
for sid in self.const_metab_sids])
self.var_initial_mn = np.array([xba_model.species[sid].initial_conc * xba_model.species[sid].mw
for sid in self.var_metab_sids])
# mass normalized stoichiometric matrices with/without external metabolites
self.sids_x_rids_mn = (sids_x_rids * self.mw_sids.reshape(-1, 1)) / mass_turnover
self.var_sids_x_rids_mn = self.sids_x_rids_mn[-len(self.var_sids):, ]
self.colsum_var_sids_x_rids_mn = np.sum(self.var_sids_x_rids_mn, axis=0)
self.dof = self.var_sids_x_rids_mn.shape[1] - np.linalg.matrix_rank(self.var_sids_x_rids_mn)
# identify unit ids for kinetic parameters
units_per_s = [XbaUnit({'kind': 'second', 'exp': -1.0, 'scale': 0, 'mult': 1.0})]
@@ -90,6 +81,125 @@ class GbaSfluxProblem:
self.kms_global = {pid: p.value for pid, p in xba_model.parameters.items()
if p.units == self.km_udid}
# model parameters
self.model_params = {pid: p.value for pid, p in xba_model.parameters.items()}
self.model_params['ext_conc_mn'] = np.array([xba_model.species[sid].initial_conc *
xba_model.species[sid].mw
for sid in self.const_metab_sids])
self.model_params['var_initial_mn'] = np.array([xba_model.species[sid].initial_conc *
xba_model.species[sid].mw
for sid in self.var_metab_sids])
# create functions used for optimization
self.jrts = self.get_reaction_times(self.rids)
self.obj_mu = jax.jit(self.jmu)
self.obj_mu_grad = jax.jit(jax.grad(self.jmu))
self.obj_mu_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jmu)))
self.constr_density = jax.jit(self.jg)
self.constr_density_grad = jax.jit(jax.grad(self.jg))
self.constr_density_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jg)))
self.constrs_conc = jax.jit(self.jconcentrations)
self.constrs_conc_jac = jax.jit(jax.jacfwd(self.jconcentrations))
self.constrs_conc_hesss = jax.jit(jax.jacrev(jax.jacfwd(self.jconcentrations)))
self.n_constraints = {'constr_density': 1,
'constrs_conc': len(self.var_sids) + len(self.enz_sids),
}
def recompile(self):
# recompilation of the jitted functions is necessary after parameters have been changed
self.jrts = self.get_reaction_times(self.rids)
self.obj_mu = jax.jit(self.jmu)
self.obj_mu_grad = jax.jit(jax.grad(self.jmu))
self.obj_mu_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jmu)))
self.constr_density = jax.jit(self.jg)
self.constr_density_grad = jax.jit(jax.grad(self.jg))
self.constr_density_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jg)))
self.constrs_conc = jax.jit(self.jconcentrations)
self.constrs_conc_jac = jax.jit(jax.jacfwd(self.jconcentrations))
self.constrs_conc_hesss = jax.jit(jax.jacrev(jax.jacfwd(self.jconcentrations)))
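A brief usage note (illustrative; sflux_problem is assumed to be a constructed GbaSfluxProblem, the value is made up): per the comment above, the jitted functions are rebuilt after any change to the model parameters, e.g.

sflux_problem.model_params['rho'] = 340.0   # hypothetical new density value
sflux_problem.recompile()                   # re-jit objective and constraint functions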
def get_w0_initial(self):
p_initial_mn = self.model_params['rho'] - sum(self.model_params['var_initial_mn'])
initial_mn = np.hstack((self.model_params['var_initial_mn'], p_initial_mn))
w0 = np.ravel(scipy.linalg.pinv(self.var_sids_x_rids_mn).dot(initial_mn) / self.model_params['rho'])
return w0
def get_mass_turnover(self, sids_x_rids):
# calculate the turnover in gram reactants/products per mol reactions
# from full stoichiometric matrix (including external metabolites)
sids_x_rids_react = np.where(sids_x_rids > 0, 0, sids_x_rids)
sids_x_rids_prod = np.where(sids_x_rids < 0, 0, sids_x_rids)
mass_turnover_react = -sids_x_rids_react.T.dot(self.mw_sids)
mass_turnover_prod = sids_x_rids_prod.T.dot(self.mw_sids)
return np.max((mass_turnover_react, mass_turnover_prod), axis=0)
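A small worked example of the mass-turnover calculation above (illustrative numbers, not model data): for a single reaction A -> 2 B with molecular weights 100 and 60 g/mol, the reaction consumes 100 g and produces 120 g per mol of flux, so the mass turnover is max(100, 120) = 120 g/mol:

import numpy as np
sids_x_rids = np.array([[-1.0], [2.0]])                    # species x reactions stoichiometry, A -> 2 B
mw_sids = np.array([100.0, 60.0])                          # molecular weights of A and B
react = np.where(sids_x_rids > 0, 0, sids_x_rids)          # keep only consumed species
prod = np.where(sids_x_rids < 0, 0, sids_x_rids)           # keep only produced species
mass_react = -react.T.dot(mw_sids)                         # array([100.])
mass_prod = prod.T.dot(mw_sids)                            # array([120.])
mass_turnover = np.max((mass_react, mass_prod), axis=0)    # array([120.])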
def jtau(self, w, model_params):
ci_mass = self.jci(w, model_params)
ca_mass = jnp.hstack((model_params['ext_conc_mn'], ci_mass))
ca_molar = jnp.divide(ca_mass, self.mw_sids)
return jnp.divide(self.jrts(ca_molar, model_params), self.rt_conv_factor)
def jmu(self, w, model_params):
return jnp.divide(self.var_sids_x_rids_mn[-1, -1] * w[-1], self.jtau(w, model_params).dot(w))
# Density constraint
# sN[:, b].dot(w[b]) -1 , with b: indices of boundary reactions
def jg(self, w, model_params):
return jnp.dot(self.colsum_var_sids_x_rids_mn, w) - 1.0
# protein concentrations
def jp(self, w, model_params):
return self.jmu(w, model_params) * model_params['rho'] * self.jtau(w, model_params) * w
# internal concentrations "i" of substrates "s" and total protein "p"
def jci(self, w, model_params):
return model_params['rho'] * jnp.dot(self.var_sids_x_rids_mn, w)
# constraint of non-negative concentrations hin(x) ≥ 0 for nlopt, !! for scipy: ≤ 0
def jconcentrations(self, w, model_params):
return jnp.hstack((self.jci(w, model_params), self.jp(w, model_params)))
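For orientation, a summary of the relations implemented above, together with an illustrative call sequence in which problem and w stand for a constructed GbaSfluxProblem and a flux-fraction vector (placeholder names, not code from this commit):

# relations implemented by the functions above:
#   mu(w) = sN[-1, -1] * w[-1] / (tau(w) . w)    growth rate in 1/h
#   g(w)  = sum_i (sN w)_i - 1                   density constraint, driven to 0
#   ci(w) = rho * (sN w)                         internal mass concentrations
#   p(w)  = mu(w) * rho * tau(w) * w             enzyme mass concentrations (elementwise)
mu = problem.obj_mu(w, problem.model_params)
density_residual = problem.constr_density(w, problem.model_params)   # ~0 at a feasible point
concentrations = problem.constrs_conc(w, problem.model_params)       # must remain non-negative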
def _to_vectors(self, rs_str):
"""replace species ids by vector elements, compartment and enzyme ids by 1.0
Variables converted to x[]
:param rs_str: reaction strings (extracted from the xba model).
:type rs_str: list of strings
:returns: rs_v_str
:rtype: list of reaction strings with parameters replaced by vector elements
"""
# replace compartment variables with 1.0
# replace enzyme variables with 1.0
# replace species ids by vector components
# make model parameters accessible
rs_v_str = []
for r_v_str in rs_str:
for compartment_id in self.xba_model.compartments:
r_v_str = re.sub(r'\b' + compartment_id + r'\b', str(1.0), r_v_str)
for enz_sid in self.enz_sids:
r_v_str = re.sub(r'\b' + enz_sid + r'\b', str(1.0), r_v_str)
for sid, idx in self.sid2idx.items():
r_v_str = re.sub(r'\b' + sid + r'\b', f'x[{idx}]', r_v_str)
for param_key in self.model_params:
r_v_str = re.sub(r'\b' + param_key + r'\b', f'model_params["{param_key}"]', r_v_str)
rs_v_str.append(r_v_str)
return rs_v_str
def get_reaction_times(self, rids):
# molar reaction turnover times
# bind the numpy namespace prefix 'np' used in the expressions to jax.numpy ('jnp')
# model parameters are passed to the reaction-time function as a function argument
turnover_times_str = [f'1.0/({self.xba_model.reactions[rid].expanded_kl})'
for rid in rids]
vs_str = self._to_vectors(turnover_times_str)
func_code = compile('def turnover_times(x, model_params): '
'return np.array([' + ', '.join(vs_str) + '])',
'<string>', 'exec')
jrts = types.FunctionType(func_code.co_consts[0], {'np': jnp})
return jrts
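A minimal sketch of the string-to-function pattern used in get_reaction_times (illustrative only; the kinetic expression and the kcat/km keys are invented for the example, not taken from the model):

import types
import jax
import jax.numpy as jnp

src = ('def turnover_times(x, model_params): '
       'return np.array([1.0 / (model_params["kcat"] * x[0] / (x[0] + model_params["km"]))])')
code = compile(src, '<string>', 'exec')
turnover_times = types.FunctionType(code.co_consts[0], {'np': jnp})   # bind 'np' to jax.numpy
jitted = jax.jit(turnover_times)                                      # function is traceable by JAX
print(jitted(jnp.array([2.0]), {'kcat': 10.0, 'km': 0.5}))            # -> [0.125]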
def get_unit_def_id(self, query_units):
for udid, ud in self.xba_model.unit_defs.items():
if ud.is_equivalent(query_units):
@@ -128,7 +238,6 @@ class GbaSfluxProblem:
:return:
:rtype:
"""
r = self.xba_model.reactions[rid]
kms_local = {pid: val for pid, [val, units] in r.local_params.items() if units == self.km_udid}
@@ -166,13 +275,23 @@ class GbaSfluxProblem:
:return: gba model mass normalized as per Hugo Dourado
:rtype: dict of pandas DataFrames
"""
protein_mn = self.rho - sum(self.var_initial_mn)
protein_mn = self.model_params['rho'] - sum(self.model_params['var_initial_mn'])
kcats = np.array([self.get_kcats(rid) for rid in self.rids]).T
skcats = kcats * 3600.0 * self.mass_turnover / self.mw_enz
km = np.zeros_like(self.sidx_x_rids_mn)
ki = np.zeros_like(self.sidx_x_rids_mn)
# in the SBML model, the kcat for protein synthesis is scaled by protein length
s = self.xba_model.species[self.enz_sids[-1]]
if hasattr(s, 'rba_process_costs'):
p_len = s.rba_process_costs['translation']
else:
r = self.xba_model.reactions[self.rids[-1]]
p_len = min(r.reactants.values())
skcats = kcats * self.rt_conv_factor
skcats[:, -1] = skcats[:, -1] / p_len
km = np.zeros_like(self.sids_x_rids_mn)
ki = np.zeros_like(self.sids_x_rids_mn)
for rid in self.rids:
kms, kis = self.get_kms(rid)
for sid, val in kms.items():
@@ -182,19 +301,19 @@ class GbaSfluxProblem:
# scale Michaelis constants with molecular weights
skm = km * self.mw_sids.reshape(-1, 1)
ski = ki * self.mw_sids.reshape(-1, 1)
ska = np.zeros_like(self.sidx_x_rids_mn)
ska = np.zeros_like(self.sids_x_rids_mn)
mn_model = {
'N': pd.DataFrame(self.sidx_x_rids_mn, index=self.sids, columns=self.rids),
'N': pd.DataFrame(self.sids_x_rids_mn, index=self.sids, columns=self.rids),
'KM': pd.DataFrame(skm, index=self.sids, columns=self.rids),
'KI': pd.DataFrame(ski, index=self.sids, columns=self.rids),
'KA': pd.DataFrame(ska, index=self.sids, columns=self.rids),
'kcat': pd.DataFrame(skcats, index=['kcat_f', 'kcat_b'], columns=self.rids),
'conditions': pd.DataFrame(np.hstack((self.rho, self.ext_conc_mn)),
'conditions': pd.DataFrame(np.hstack((self.model_params['rho'], self.model_params['ext_conc_mn'])),
index=['rho'] + self.const_metab_sids, columns=[1]),
'lower_c': pd.DataFrame(np.zeros(len(self.var_sids)),
index=self.var_sids, columns=['lower']),
'initial': pd.DataFrame(np.hstack((self.var_initial_mn, protein_mn)),
'initial': pd.DataFrame(np.hstack((self.model_params['var_initial_mn'], protein_mn)),
index=self.var_sids, columns=['initial']),
}
return mn_model
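Illustrative usage only (the method name and the output file name are assumptions; the hunk does not show the enclosing def): the returned dict of DataFrames can be written out sheet by sheet for downstream use:

import pandas as pd
mn_model = sflux_problem.get_mn_model()              # method name assumed; returns the dict built above
with pd.ExcelWriter('gba_mn_model.xlsx') as writer:  # hypothetical file name
    for sheet, df in mn_model.items():
        df.to_excel(writer, sheet_name=sheet)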
@@ -2,5 +2,6 @@
from .glpk_linear_problem import GlpkLinearProblem
from .gba_ipopt_problem import GbaIpoptProblem
from .gba_sflux_ipopt_problem import GbaSfluxIpoptProblem
__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem']
__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem', 'GbaSfluxIpoptProblem']
"""Implementation of GbaSfluxIpoptProblem class.
For solving GBA problem using scaled fluxes as per Hugo Dourado
Peter Schubert, HHU Duesseldorf, September 2022
"""
import numpy as np
# based on cyipopt examples
class GbaSfluxIpoptProblem:
def __init__(self, gba_sflux_problem, model_params=None):
"""Initialize sFluxIpoptProblem.
:param gba_sflux_problem: GbaSfluxProblem from where to extract relevant parameters
:type gba_sflux_problem: GbaSfluxProblem
:param model_params: model parameters (optional, default None; uses the params from GbaSfluxProblem)
:type model_params: None or dict
"""
self.sflux_problem = gba_sflux_problem
if model_params is None:
self.model_params = gba_sflux_problem.model_params
else:
self.model_params = model_params
self._n_vars = len(gba_sflux_problem.rids)
self.report_freq = 0
self.nfev = 0
self.njev = 0
self.nit = 0
self.constraint_dims = sum(gba_sflux_problem.n_constraints.values())
def objective(self, x):
self.nfev += 1
return -np.float64(self.sflux_problem.obj_mu(x, self.model_params))
def gradient(self, x):
self.njev += 1
return -np.array(self.sflux_problem.obj_mu_grad(x, self.model_params))
def constraints(self, x):
# TODO: improve code
eqs = np.hstack((self.sflux_problem.constr_density(x, self.model_params),
self.sflux_problem.constrs_conc(x, self.model_params)))
return eqs
def jacobian(self, x):
# TODO: improve code
jacs = (np.vstack((self.sflux_problem.constr_density_grad(x, self.model_params),
self.sflux_problem.constrs_conc_jac(x, self.model_params)))
).flatten()
return jacs
# def hessianstructure(self):
# return np.tril_indices(self._n_vars)
def hessian(self, x, lagrange, obj_factor):
hess = obj_factor * -self.sflux_problem.obj_mu_hess(x, self.model_params)
hesss = np.vstack(([self.sflux_problem.constr_density_hess(x, self.model_params)],
self.sflux_problem.constrs_conc_hesss(x, self.model_params)))
hess += (lagrange.reshape((-1, 1, 1)) * hesss).sum(axis=0)
return hess[np.tril_indices(self._n_vars)]
def set_report_freq(self, mod_val=0):
self.report_freq = mod_val
# noinspection PyUnusedLocal
def intermediate(self, alg_mod, iter_count, obj_value, inf_pr, inf_du, mu,
d_norm, regularization_size, alpha_du, alpha_pr, ls_trials):
self.nit = iter_count
if self.report_freq > 0:
if iter_count % self.report_freq == 0:
print(f'[{iter_count:5d}] growth rate: {-obj_value:.4f} 1/h')
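Illustrative end-to-end sketch (not part of the commit; bounds, solver options and variable names are assumptions): the class above provides the objective/gradient/constraints/jacobian/hessian/intermediate callbacks that cyipopt.Problem expects:

import numpy as np
import cyipopt

problem = GbaSfluxProblem(xba_model)                 # assumes a loaded XbaModel instance
ipopt_problem = GbaSfluxIpoptProblem(problem)
ipopt_problem.set_report_freq(10)

n = len(problem.rids)                                # one scaled flux per reaction
m = ipopt_problem.constraint_dims                    # density constraint + concentration constraints
nlp = cyipopt.Problem(n=n, m=m, problem_obj=ipopt_problem,
                      lb=np.zeros(n), ub=np.ones(n),                    # assumed box bounds on scaled fluxes
                      cl=np.zeros(m),                                   # density == 0, concentrations >= 0
                      cu=np.hstack((0.0, np.full(m - 1, np.inf))))
nlp.add_option('max_iter', 500)
x0 = problem.get_w0_initial()                        # pseudo-inverse based initial point (see above)
x_opt, info = nlp.solve(x0)
print(f"max growth rate: {-info['obj_val']:.4f} 1/h")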