From 3ed33387e734c7cdd498a0cc7f4a65c8d711b634 Mon Sep 17 00:00:00 2001 From: Peter Schubert <Peter.Schubert@hhu.de> Date: Fri, 23 Sep 2022 16:24:48 +0200 Subject: [PATCH] gba_stoic_problem solved via ipopt using jax autograd and jit --- xbanalysis/problems/__init__.py | 5 +- xbanalysis/problems/gba_problem.py | 14 +- xbanalysis/problems/gba_sflux_problem.py | 2 +- xbanalysis/problems/gba_stoic_problem.py | 439 ++++++++++++++++++ xbanalysis/problems/rba_problem.py | 4 +- xbanalysis/solvers/__init__.py | 7 +- ..._ipopt_problem.py => ipopt_gba_problem.py} | 8 +- ..._problem.py => ipopt_gba_sflux_problem.py} | 28 +- xbanalysis/solvers/ipopt_gba_stoic_problem.py | 95 ++++ xbanalysis/utils/opt_results.py | 22 +- 10 files changed, 590 insertions(+), 34 deletions(-) create mode 100644 xbanalysis/problems/gba_stoic_problem.py rename xbanalysis/solvers/{gba_ipopt_problem.py => ipopt_gba_problem.py} (95%) rename xbanalysis/solvers/{gba_sflux_ipopt_problem.py => ipopt_gba_sflux_problem.py} (73%) create mode 100644 xbanalysis/solvers/ipopt_gba_stoic_problem.py diff --git a/xbanalysis/problems/__init__.py b/xbanalysis/problems/__init__.py index a43385f..7df3913 100644 --- a/xbanalysis/problems/__init__.py +++ b/xbanalysis/problems/__init__.py @@ -1,8 +1,9 @@ """Subpackage with XBA model classes """ -from .gba_problem import GbaProblem from .rba_problem import RbaProblem from .fba_problem import FbaProblem +from .gba_problem import GbaProblem from .gba_sflux_problem import GbaSfluxProblem +from .gba_stoic_problem import GbaStoicProblem -__all__ = ['GbaProblem', 'FbaProblem', 'RbaProblem', 'GbaSfluxProblem'] +__all__ = ['GbaProblem', 'FbaProblem', 'RbaProblem', 'GbaSfluxProblem', 'GbaStoicProblem'] diff --git a/xbanalysis/problems/gba_problem.py b/xbanalysis/problems/gba_problem.py index 55d9e58..b9354bf 100644 --- a/xbanalysis/problems/gba_problem.py +++ b/xbanalysis/problems/gba_problem.py @@ -59,6 +59,7 @@ class GbaProblem: self.var_mask_metab = [True if sid in self.var_metab_sids else False for sid in self.var_sids] self.var_mask_enz = [True if sid in self.var_enz_sids else False for sid in self.var_sids] self.var_initial = np.array([xba_model.species[sid].initial_conc for sid in self.var_sids]) + self.var_mws = np.array([xba_model.species[sid].mw for sid in self.var_sids]) self.var_vols = np.array([xba_model.compartments[xba_model.species[sid].compartment].size for sid in self.var_sids]) @@ -110,7 +111,6 @@ class GbaProblem: self.model_params['ext_conc'] = np.array([xba_model.species[sid].initial_conc for sid in self.const_metab_sids]) self.model_params['np'] = np - self.model_params['mws'] = np.array([xba_model.species[sid].mw for sid in self.var_sids]) # create functions self.fmras, self.fmras_jac, self.fmras_hesss = self.get_reactions(self.mr_rids, inv=False) @@ -389,7 +389,9 @@ class GbaProblem: considering stoichiometric coefficiens unequal unity for enzyme turnover stoichiometric sub-matrices do not require volume scaling (as Vol_enz gets factored out) - With dilution (negative stoichiometric coefficients in enz_x_mrs), growth rate reduces + With degradation (!!negative stoichiometric coefficients in enz_x_mrs), growth rate reduces + - first calculate the enzyme amount per time being degraded + Equations: gr(x) = gr0(x) * (1 + enz_to_mras(x).dot(pstmas(x))) @@ -714,9 +716,9 @@ class GbaProblem: """ ce = x[self.var_mask_enz] if self.model_params.get('macro', True) is True: - density = self.model_params['mws'][self.var_mask_enz].dot(ce) + density = self.var_mws[self.var_mask_enz].dot(ce) else: - density = self.model_params['mws'].dot(x) + density = self.var_mws.dot(x) return np.array([self.model_params['rho'] - density]) * self.scale_density @cache_last @@ -734,9 +736,9 @@ class GbaProblem: """ if self.model_params.get('macro', True) is True: density_grad = np.zeros(len(x)) - density_grad[self.var_mask_enz] = -self.model_params['mws'][self.var_mask_enz] + density_grad[self.var_mask_enz] = -self.var_mws[self.var_mask_enz] else: - density_grad = -self.model_params['mws'] + density_grad = -self.var_mws return density_grad * self.scale_density # noinspection PyUnusedLocal diff --git a/xbanalysis/problems/gba_sflux_problem.py b/xbanalysis/problems/gba_sflux_problem.py index 26a8b09..672fd34 100644 --- a/xbanalysis/problems/gba_sflux_problem.py +++ b/xbanalysis/problems/gba_sflux_problem.py @@ -54,7 +54,7 @@ class GbaSfluxProblem: rid_ps_ribo = xba_model.species[sid_ribo].ps_rid self.rids = self.mr_rids + [rid_ps_ribo] self.rid2idx = {mrid: idx for idx, mrid in enumerate(self.rids)} - self.enz_sids = np.array([self.xba_model.reactions[rid].enzyme for rid in self.rids]) + self.enz_sids = np.array([xba_model.reactions[rid].enzyme for rid in self.rids]) self.mw_sids = np.array([xba_model.species[sid].mw for sid in self.sids]) self.mw_enz = np.array([xba_model.species[enz_sid].mw for enz_sid in self.enz_sids]) diff --git a/xbanalysis/problems/gba_stoic_problem.py b/xbanalysis/problems/gba_stoic_problem.py new file mode 100644 index 0000000..87f55fa --- /dev/null +++ b/xbanalysis/problems/gba_stoic_problem.py @@ -0,0 +1,439 @@ +"""Implementation of GBA stoic Problem class. + +based on gba_problem +- extended Growth Balance Analysis (GBA) problem from Deniz Sezer +- basic problem uses internal metabolite and enzyme concentrations + as free variables (optimization variables) + it optimizes growth rate under constraints of mass balance and + total density (cellular density or macromolecular density) +- extended formulation takes care about + - protein degradation + - protein interconversion + - spontaneous reactions + +- a related formulation is based on Hugo Dourado scalled fluxes + see gba_sflux_problem.py + +- this implementation uses the JAX package to + - determine first and second derivatives of all functions + - just in time compilation for speed-up + - support of accellerators (e.g. GPUs) - yet to be tested + +Peter Schubert, HHU Duesseldorf, September 2022 +""" + +import numpy as np +import re +import types +import scipy +import jax +import jax.numpy as jnp + +jax.config.update("jax_enable_x64", True) + +# TODO: variable transformation x = exp(y) and solve for unbounded y. start with vector y = log(x) + + +class GbaStoicProblem: + + cache = True + last_values = {} + + def __init__(self, xba_model): + """Initialize GbaStoichProblem from a xba_model. + + :param xba_model: xba model with all required GBA parameters + :type xba_model: XbaModel + """ + self.xba_model = xba_model + + # check to scale density constraint to improve convergence + self.scale_density = 1e-4 + + # metabolites and enzymes in the model + metab_sids = [s.id for s in xba_model.species.values() + if s.sboterm.is_in_branch('SBO:0000247')] + + # optimization variables + self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False] + self.var_enz_sids = [s.id for s in xba_model.species.values() + if s.sboterm.is_in_branch('SBO:0000246')] + self.var_sids = self.var_metab_sids + self.var_enz_sids + self.n_vars = len(self.var_sids) + self.var_sid2idx = {sid: idx for idx, sid in enumerate(self.var_sids)} + self.var_mask_metab = [True if sid in self.var_metab_sids else False for sid in self.var_sids] + self.var_mask_enz = [True if sid in self.var_enz_sids else False for sid in self.var_sids] + self.var_initial = np.array([xba_model.species[sid].initial_conc for sid in self.var_sids]) + self.var_mws = np.array([xba_model.species[sid].mw for sid in self.var_sids]) + self.var_vols = np.array([xba_model.compartments[xba_model.species[sid].compartment].size + for sid in self.var_sids]) + + # constant concentrations + self.const_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is True] + self.const_sid2idx = {sid: idx for idx, sid in enumerate(self.const_metab_sids)} + + # protein synthesis reactions get ordered according to the order of their enzyme products + # to facilitate matrix multiplication without further remapping + # - here we expect a one-to-one relationship between enzyme and corresponding protein synthesis reaction + self.mr_rids = [r.id for r in xba_model.reactions.values() + if r.sboterm.is_in_branch('SBO:0000167')] + self.ps_rids = [xba_model.species[sid].ps_rid for sid in self.var_enz_sids] + self.mr_vols = np.array([xba_model.compartments[xba_model.reactions[rid].compartment].size + for rid in self.mr_rids]) + self.ps_vols = np.array([xba_model.compartments[xba_model.reactions[rid].compartment].size + for rid in self.ps_rids]) + + # calculate stoichiometric sub-matrices and scale them according to volume + # TODO: sparce matrix storage and calculations + species_x_mrs = xba_model.get_stoic_matrix(self.var_sids, self.mr_rids) + self.dof = species_x_mrs.shape[1] - np.linalg.matrix_rank(species_x_mrs) + metab_x_mrs = xba_model.get_stoic_matrix(self.var_metab_sids, self.mr_rids) + self.metab_x_psrs = xba_model.get_stoic_matrix(self.var_metab_sids, self.ps_rids) + self.enz_x_mrs = xba_model.get_stoic_matrix(self.var_enz_sids, self.mr_rids) + self.enz_vols = self.var_vols[self.var_mask_enz] + metab_vols = self.var_vols[self.var_mask_metab] + + self.metab_x_mras = metab_x_mrs / metab_vols.reshape((-1, 1)) + self.metab_x_psras = self.metab_x_psrs / metab_vols.reshape((-1, 1)) + self.enz_x_mras = self.enz_x_mrs / self.enz_vols.reshape((-1, 1)) + + # identify reactions ids for enzyme degradation and corresponding submatrix + enz_degrad_idxs = self.get_export_idxs(self.enz_x_mrs) + enz_degrad_rids = np.array(self.mr_rids)[enz_degrad_idxs] + self.enz_x_degrads = self.enz_x_mrs[:, enz_degrad_idxs] + + # identify reactions ids for enzyme transitions and corresponding submatrix + # for submatrix perform LU decomposition and remove empty rows from U matrix + enz_trans_idxs = self.get_trans_idxs(self.enz_x_mras) + enz_trans_rids = np.array(self.mr_rids)[enz_trans_idxs] + u = scipy.linalg.lu((self.enz_x_mras[:, enz_trans_idxs]))[2] + self.subenz_x_trans = u[abs(u).sum(axis=1) > 0] + + # model parameters + self.model_params = {pid: p.value for pid, p in xba_model.parameters.items()} + self.model_params.update({cid: c.size for cid, c in xba_model.compartments.items()}) + self.model_params['macro'] = self.model_params.get('macro', 0.0) > 0.0 + self.model_params['ext_conc'] = np.array([xba_model.species[sid].initial_conc + for sid in self.const_metab_sids]) + + # move some numpy structures to jax.numpy + # TODO: possibly code relevant numpy objects in jax.numpy to allow for accelleration + self.jvar_mask_enz = jnp.array(self.var_mask_enz) + self.jenz_x_mrs = jnp.array(self.enz_x_mrs) + self.jvar_mask_metab = jnp.array(self.var_mask_metab) + self.jenz_x_degrads = jnp.array(self.enz_x_degrads) + self.jenz_vols = jnp.array(self.enz_vols) + self.jsubenz_x_trans = jnp.array(self.subenz_x_trans) + self.jvar_mws = jnp.array(self.var_mws) + self.jmetab_x_mras = jnp.array(self.metab_x_mras) + self.jmetab_x_psrs = jnp.array(self.metab_x_psrs) + self.jenz_x_mras = jnp.array(self.enz_x_mras) + self.jmetab_x_psrs = jnp.array(self.metab_x_psrs) + + # retrieve kinetic functions and compile them + self.mras = self.get_reactions(self.mr_rids, inv=False) + self.pstmas = self.get_reactions(self.ps_rids, inv=True) + self.jmras_jit = jax.jit(self.mras) + self.jpstmas_jit = jax.jit(self.pstmas) + self.jdegradas_jit = jax.jit(self.get_reactions(enz_degrad_rids, inv=False)) + self.jenz_tras_jit = jax.jit(self.get_reactions(enz_trans_rids, inv=False)) + + # create functions used for optimization + self.jobj_mu_jit = jax.jit(self.get_growth_rate) + self.jobj_mu_grad_jit = jax.jit(jax.grad(self.jobj_mu_jit)) + self.jobj_mu_hess_jit = jax.jit(jax.jacfwd(self.jobj_mu_grad_jit)) + + self.jeq_density_jit = jax.jit(self.heq_density) + self.jeq_density_jac_jit = jax.jit(jax.jacrev(self.jeq_density_jit)) + self.jeq_density_hesss_jit = jax.jit(jax.jacfwd(self.jeq_density_jac_jit)) + self.jeq_density_macro_jit = jax.jit(self.heq_density_macro) + self.jeq_density_macro_jac_jit = jax.jit(jax.jacrev(self.jeq_density_macro_jit)) + self.jeq_density_macro_hesss_jit = jax.jit(jax.jacfwd(self.jeq_density_macro_jac_jit)) + + self.jeq_mass_balance_jit = jax.jit(self.heq_mass_balance) + self.jeq_mass_balance_jac_jit = jax.jit(jax.jacrev(self.jeq_mass_balance_jit)) + self.jeq_mass_balance_hesss_jit = jax.jit(jax.jacfwd(self.jeq_mass_balance_jac_jit)) + + self.jeq_enz_trans_jit = jax.jit(self.heq_enz_trans) + self.jeq_enz_trans_jac_jit = jax.jit(jax.jacrev(self.jeq_enz_trans_jit)) + self.jeq_enz_trans_hesss_jit = jax.jit(jax.jacfwd(self.jeq_enz_trans_jac_jit)) + + self.n_constraints = {'heq_mass_balance': len(self.var_metab_sids), + 'heq_enz_trans': self.subenz_x_trans.shape[0], + 'heq_density': 1} + + def compile(self, x, model_params): + jx = jnp.array(x) + + self.jmras_jit(jx, model_params) + self.jpstmas_jit(jx, model_params) + self.jdegradas_jit(jx, model_params) + self.jenz_tras_jit(jx, model_params) + + self.jobj_mu_jit(jx, model_params) + self.jobj_mu_grad_jit(jx, model_params) + self.jobj_mu_hess_jit(jx, model_params) + + self.jeq_density_jit(jx, model_params) + self.jeq_density_jac_jit(jx, model_params) + self.jeq_density_hesss_jit(jx, model_params) + self.jeq_density_macro_jit(jx, model_params) + self.jeq_density_macro_jac_jit(jx, model_params) + self.jeq_density_macro_hesss_jit(jx, model_params) + + self.jeq_mass_balance_jit(jx, model_params) + self.jeq_mass_balance_jac_jit(jx, model_params) + self.jeq_mass_balance_hesss_jit(jx, model_params) + + self.jeq_enz_trans_jit(jx, model_params) + self.jeq_enz_trans_jac_jit(jx, model_params) + self.jeq_enz_trans_hesss_jit(jx, model_params) + + def update_model_params(self, key, val): + self.model_params[key] = val + + def _to_vectors(self, rs_str): + """replace variables and constant metabolites by vector elements. + + Variables converted to x[], external metabolites to ext_conc[]. + access to model parameters via model_params dict + + :param rs_str: reaction strings (extracted from the xba model). + :type rs_str: list of strings + :returns: rs_v_str + :rtype: list of reaction stings with parameters replaced by vector elements + """ + rs_v_str = [] + for r_v_str in rs_str: + for idx, sid in enumerate(self.var_sids): + r_v_str = re.sub(r'\b' + sid + r'\b', f'x[{idx}]', r_v_str) + for idx, sid in enumerate(self.const_metab_sids): + r_v_str = re.sub(r'\b' + sid + r'\b', f'ext_conc[{idx}]', r_v_str) + for param_key in self.model_params: + r_v_str = re.sub(r'\b' + param_key + r'\b', f'model_params["{param_key}"]', r_v_str) + rs_v_str.append(r_v_str) + return rs_v_str + + def get_reactions(self, rids, inv=False): + """create function objects for reactions rates/times within jax.numpy namespace + + extracted from kinetic laws for SBML model + + Note: reactions in the SBML model are in amounts per time (not concentration per time) + I.e. for mass balances equations, we have to devide by volumes where metabolite is located + + :param rids: reaction ids for reactions in amount per time to be processed. + :type rids: list of strings + :param inv: flag indicating if inverse of reaction required (i.e. reaction times) + :type inv: boolean (default: False) + :returns: jras + :rtype: function object, returning an 1D array of function results + """ + if inv is False: + kinetics_str = [self.xba_model.reactions[rid].expanded_kl for rid in rids] + else: + kinetics_str = [f'1.0/({self.xba_model.reactions[rid].expanded_kl})' for rid in rids] + v_str = self._to_vectors(kinetics_str) + func_code = compile('def reactions(x, model_params): ' + 'return np.array([' + ', '.join(v_str) + '])', + '<string>', 'exec') + jras = types.FunctionType(func_code.co_consts[0], {'np': jnp}) + return jras + + @staticmethod + def get_trans_idxs(smat): + """Identify metabolic reactions in S matrix with species transitions. + + I.e. columns containing both positive and negative entries + + :param smat: + :type smat: + :return: bolean vector identifying relevant columns + :rtype: 1D ndarray with dtype bool + """ + mask_trans_idxs = np.zeros(smat.shape[1]) + for col_idx in range(smat.shape[1]): + if ((np.count_nonzero(smat[:, col_idx] < 0) > 0) and + np.count_nonzero(smat[:, col_idx] > 0) > 0): + mask_trans_idxs[col_idx] = 1.0 + return mask_trans_idxs > 0 + + @staticmethod + def get_export_idxs(smat): + """Identify metabolic reactions in S matrix with consumption only. + + I.e. columns containing only negative entries + + :param smat: + :type smat: + :return: bolean vector identifying relevant columns + :rtype: 1D ndarray with dtype bool + """ + mask_remove_idxs = np.zeros(smat.shape[1]) + for col_idx in range(smat.shape[1]): + if ((np.count_nonzero(smat[:, col_idx] < 0) > 0) and + np.count_nonzero(smat[:, col_idx] > 0) == 0): + mask_remove_idxs[col_idx] = 1.0 + return mask_remove_idxs > 0 + + # @jax.jit + def get_growth_rate_0(self, x, model_params): + """Calculate growth rate at x, assuming no protein degradation + + Equations: + gr0(x) = 1/tau(x) + tau(x) = pstms(x).dot(ce) + pstms(x) = enz_vols * ptsmas(x) + + :param x: optimization variables + :type x: numpy.ndarray + :param model_params: model parameters that can be changed + :type model_params: dict + :returns: growth rate at x (without protein degradation) + :rtype: float + """ + jx = jnp.array(x) + ce = jx[self.jvar_mask_enz] + pstms = self.jenz_vols * self.jpstmas_jit(jx, model_params) + gr0 = jnp.divide(1.0, jnp.dot(pstms, ce)) + return gr0 + + # @jax.jit + def get_growth_rate(self, x, model_params): + """Calculate growth rate at x, considering protein degradation + + considering stoichiometric coefficiens unequal unity for enzyme turnover + stoichiometric sub-matrices do not require volume scaling (as Vol_enz gets factored out) + With degradation (!!negative stoichiometric coefficients in enz_x_mrs), growth rate reduces + - first calculate the enzyme amount per time being degraded + + Equations: + gr(x) = gr0(x) * (1 + enz_to_mras(x).dot(pstmas(x))) + enz_to_mras(x) = enz_x_mrs.dot(mras(x)) + + :param x: optimization variables + :type x: numpy.ndarray + :param model_params: model parameters that can be changed + :type model_params: dict + :returns: growth rate at x (considering protein degradation) + :rtype: float + """ + jx = jnp.array(x) + gr0 = self.get_growth_rate_0(jx, model_params) + pstmas = self.jpstmas_jit(jx, model_params) + enz_to_mras = jnp.dot(self.jenz_x_degrads, self.jdegradas_jit(x, model_params)) + gr = gr0 * (1.0 + jnp.dot(enz_to_mras, pstmas)) + return gr + + def get_alphas(self, x, model_params): + """Calculate ribosomal allocations at x, considering protein degradation + + alpha_e = gr * ce_e * pstm_e + p_degr_e * pstm_e + alpha = gr * ce.dot(pstms(x)) + p_degrs(x).dot(pstms(x)) + + :param x: optimization variables + :type x: numpy.ndarray + :param model_params: model parameters that can be changed + :type model_params: dict + :returns: ribosomal allocations at x (considering protein degradation) + :rtype: numpy.ndarray (1D) + """ + jx = jnp.array(x) + ce = jx[self.jvar_mask_enz] + gr = self.get_growth_rate(jx, model_params) + pstms = self.jenz_vols * self.jpstmas_jit(jx, model_params) + + alphas0 = gr * ce * pstms + pstmas = self.jpstmas_jit(jx, model_params) + enz_to_mras = jnp.dot(self.jenz_x_degrads, self.jdegradas_jit(jx, model_params)) + alphas = alphas0 - enz_to_mras * pstmas + return alphas + + def heq_mass_balance(self, x, model_params): + """Calculate mass balance equations at x for metabolites + + incuding protein degradation, which consume metabolites + Note: protein degradation is also part of metabolic reactions (mras), which return + metabolites. + + metab_mb = dcm/dt = metabolic_balance + growth_dilution + sythesis_enzyme_balance + metabolic_balance = metab_x_mras.dot(mras(x)) + growth_dilution = λ * (metab_x_psrs.dot(ce) - cm) + ce_scaled = enz_vols * ce # convert concentration to amounts, considering stoichiometry + sythesis_enzyme_balance = - metab_x_psras.dot(enz_vols * enzyme_balance) + enzyme_balance = enz_x_mras.dot(mras(x)) + + :param x: optimization variables + :type x: numpy.ndarray + :param model_params: model parameters that can be changed + :type model_params: dict + :returns: mass balance calculated at x + :rtype: numpy.ndarray (1D) with shape (len(var_mask_metab), 1) + """ + jx = jnp.array(x) + gr = self.get_growth_rate(jx, model_params) + cm = jx[self.jvar_mask_metab] + ce = jx[self.jvar_mask_enz] + + metabolic_mb = jnp.dot(self.jmetab_x_mras, self.jmras_jit(jx, model_params)) + dilution = jnp.dot(self.jmetab_x_psrs, ce) - cm + metab_mb = metabolic_mb + gr * dilution + enz_mb = jnp.dot(self.jenz_x_mras, self.jmras_jit(jx, model_params)) + metab_mb -= jnp.dot(self.jmetab_x_psrs, enz_mb) + return metab_mb + + def heq_enz_trans(self, x, model_params): + """Calculate enzyme transition equations at x + + :param x: optimization variables + :type x: numpy.ndarray + :param model_params: model parameters that can be changed + :type model_params: dict + :returns: Jacobian of mass balance equations at x + :rtype: numpy.ndarray (1D) with shape (len(subenz_x_trans.shape[1]), 1) + """ + jx = jnp.array(x) + return jnp.dot(self.jsubenz_x_trans, self.jenz_tras_jit(jx, model_params)) + + def heq_density(self, x, model_params): + """Calculate dry mass density (g/l) constraint at x + + Could be extended having several density constraints (not yet implemented) + + alternalively, we could check macromolecular (protein) density using heq_density_macro + + using ipopt, the constraint bounds could be set to rho (only add up masses) + density = sum(species_concentration * molecular_weight) + + :param x: optimization variables + :type x: numpy.ndarray + :param model_params: model parameters that can be changed + :type model_params: dict + :returns: mass balance calculated at x + :rtype: numpy.ndarray (1D) with shape (1,) for now + """ + jx = jnp.array(x) + density = jnp.dot(self.jvar_mws, jx) + return jnp.array([model_params['rho'] - density]) * self.scale_density + + def heq_density_macro(self, x, model_params): + """Calculate macromolecular density (g/l) constraint at x + + Could be extended having several density constraints (not yet implemented) + + macromolecular density is total protein mass density + + using ipopt, the constraint bounds could be set to rho (only add up masses) + density = sum(species_concentration * molecular_weight) + + :param x: optimization variables + :type x: numpy.ndarray + :param model_params: model parameters that can be changed + :type model_params: dict + :returns: mass balance calculated at x + :rtype: numpy.ndarray (1D) with shape (1,) for now + """ + jx = jnp.array(x) + ce = jx[self.jvar_mask_enz] + density = jnp.dot(self.jvar_mws[self.jvar_mask_enz], ce) + return jnp.array([model_params['rho'] - density]) * self.scale_density diff --git a/xbanalysis/problems/rba_problem.py b/xbanalysis/problems/rba_problem.py index dc0fafe..7575115 100644 --- a/xbanalysis/problems/rba_problem.py +++ b/xbanalysis/problems/rba_problem.py @@ -18,6 +18,7 @@ from scipy.sparse import coo_array, hstack, vstack from xbanalysis.solvers.glpk_linear_problem import GlpkLinearProblem +# TODO: implement protein degradation class RbaProblem: @@ -26,7 +27,6 @@ class RbaProblem: :param xba_model: """ - self.xba_model = xba_model # metabolic reactions and enzyme transitions (no FBA reactions, no protein synthesis, no degradation) @@ -36,7 +36,7 @@ class RbaProblem: self.rid2enz = {rid: xba_model.reactions[rid].enzyme for rid in self.rids} self.enz_sids = np.array([enz_sid for enz_sid in self.rid2enz.values() if enz_sid is not None]) self.enz_sids_rev = np.array([enz_sid for rid, enz_sid in self.rid2enz.items() - if enz_sid is not None and xba_model.reactions[rid].reversible]) + if enz_sid is not None and xba_model.reactions[rid].reversible]) self.enz2idx = {enz: idx for idx, enz in enumerate(self.enz_sids)} self.processes = {s.rba_process_machine: sid for sid, s in xba_model.species.items() if hasattr(s, 'rba_process_machine')} diff --git a/xbanalysis/solvers/__init__.py b/xbanalysis/solvers/__init__.py index 4e4ee2e..6d25e64 100644 --- a/xbanalysis/solvers/__init__.py +++ b/xbanalysis/solvers/__init__.py @@ -1,7 +1,8 @@ """Subpackage with XBA model classes """ from .glpk_linear_problem import GlpkLinearProblem -from .gba_ipopt_problem import GbaIpoptProblem -from .gba_sflux_ipopt_problem import GbaSfluxIpoptProblem +from .ipopt_gba_problem import IpoptGbaProblem +from .ipopt_gba_sflux_problem import IpoptGbaSfluxProblem +from .ipopt_gba_stoic_problem import IpoptGbaStoicProblem -__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem', 'GbaSfluxIpoptProblem'] +__all__ = ['GlpkLinearProblem', 'IpoptGbaProblem', 'IpoptGbaSfluxProblem', 'IpoptGbaStoicProblem'] diff --git a/xbanalysis/solvers/gba_ipopt_problem.py b/xbanalysis/solvers/ipopt_gba_problem.py similarity index 95% rename from xbanalysis/solvers/gba_ipopt_problem.py rename to xbanalysis/solvers/ipopt_gba_problem.py index 958431c..fa129b5 100644 --- a/xbanalysis/solvers/gba_ipopt_problem.py +++ b/xbanalysis/solvers/ipopt_gba_problem.py @@ -1,4 +1,4 @@ -"""Implementation of GbaIpoptProblem class. +"""Implementation of IpoptGbaProblem class. Peter Schubert, HHU Duesseldorf, June 2022 """ @@ -7,7 +7,7 @@ import numpy as np # based on cyipopt examples -class GbaIpoptProblem: +class IpoptGbaProblem: def __init__(self, gba_problem): """Initialize GbaIpoptProblem. @@ -56,8 +56,8 @@ class GbaIpoptProblem: ).flatten() return jacs - def hessianstructure(self): - return np.tril_indices(self._n_vars) + # def hessianstructure(self): + # return np.tril_indices(self._n_vars) def hessian(self, x, lagrange, obj_factor): hess = obj_factor * -self.gba_problem.get_growth_rate_hess(x) diff --git a/xbanalysis/solvers/gba_sflux_ipopt_problem.py b/xbanalysis/solvers/ipopt_gba_sflux_problem.py similarity index 73% rename from xbanalysis/solvers/gba_sflux_ipopt_problem.py rename to xbanalysis/solvers/ipopt_gba_sflux_problem.py index 49718a3..e4a1f5b 100644 --- a/xbanalysis/solvers/gba_sflux_ipopt_problem.py +++ b/xbanalysis/solvers/ipopt_gba_sflux_problem.py @@ -1,4 +1,4 @@ -"""Implementation of GbaSfluxIpoptProblem class. +"""Implementation of IpoptGbaSfluxProblem class. For solving GBA problem using scaled fluxes as per Hugo Dourado @@ -6,10 +6,11 @@ Peter Schubert, HHU Duesseldorf, September 2022 """ import numpy as np +import jax.numpy as jnp # based on cyipopt examples -class GbaSfluxIpoptProblem: +class IpoptGbaSfluxProblem: def __init__(self, gba_sflux_problem, model_params=None): """Initialize sFluxIpoptProblem. @@ -33,23 +34,27 @@ class GbaSfluxIpoptProblem: self.constraint_dims = sum(gba_sflux_problem.n_constraints.values()) def objective(self, x): + jx = jnp.array(x) self.nfev += 1 - return -np.float64(self.sflux_problem.obj_mu(x, self.model_params)) + return -np.float64(self.sflux_problem.obj_mu(jx, self.model_params)) def gradient(self, x): + jx = jnp.array(x) self.njev += 1 - return -np.array(self.sflux_problem.obj_mu_grad(x, self.model_params)) + return -np.array(self.sflux_problem.obj_mu_grad(jx, self.model_params)) def constraints(self, x): + jx = jnp.array(x) # TODO: improve code - eqs = np.hstack((self.sflux_problem.constr_density(x, self.model_params), - self.sflux_problem.constrs_conc(x, self.model_params))) + eqs = np.hstack((self.sflux_problem.constr_density(jx, self.model_params), + self.sflux_problem.constrs_conc(jx, self.model_params))) return eqs def jacobian(self, x): # TODO: improve code - jacs = (np.vstack((self.sflux_problem.constr_density_grad(x, self.model_params), - self.sflux_problem.constrs_conc_jac(x, self.model_params))) + jx = jnp.array(x) + jacs = (np.vstack((self.sflux_problem.constr_density_grad(jx, self.model_params), + self.sflux_problem.constrs_conc_jac(jx, self.model_params))) ).flatten() return jacs @@ -57,9 +62,10 @@ class GbaSfluxIpoptProblem: # return np.tril_indices(self._n_vars) def hessian(self, x, lagrange, obj_factor): - hess = obj_factor * -self.sflux_problem.obj_mu_hess(x, self.model_params) - hesss = np.vstack(([self.sflux_problem.constr_density_hess(x, self.model_params)], - self.sflux_problem.constrs_conc_hesss(x, self.model_params))) + jx = jnp.array(x) + hess = obj_factor * -self.sflux_problem.obj_mu_hess(jx, self.model_params) + hesss = np.vstack(([self.sflux_problem.constr_density_hess(jx, self.model_params)], + self.sflux_problem.constrs_conc_hesss(jx, self.model_params))) hess += (lagrange.reshape((-1, 1, 1)) * hesss).sum(axis=0) return hess[np.tril_indices(self._n_vars)] diff --git a/xbanalysis/solvers/ipopt_gba_stoic_problem.py b/xbanalysis/solvers/ipopt_gba_stoic_problem.py new file mode 100644 index 0000000..1303d03 --- /dev/null +++ b/xbanalysis/solvers/ipopt_gba_stoic_problem.py @@ -0,0 +1,95 @@ +"""Implementation of IpoptGbaStoicProblem class. + +For solving GBA problem using scaled fluxes as per Hugo Dourado + +Peter Schubert, HHU Duesseldorf, September 2022 +""" + +import numpy as np +import jax.numpy as jnp + + +# based on cyipopt examples +class IpoptGbaStoicProblem: + + def __init__(self, gba_stoic_problem, model_params=None): + """Initialize sFluxIpoptProblem. + + :param gba_stoic_problem: GbaStoicProblem from where to extract relevant parameters + :type gba_stoic_problem: GbaStoicProblem + :param model_params: model parameters (optional, default None, use params from GbaStoicProblem + :type model_params: None or dict + """ + self.stoic_problem = gba_stoic_problem + if model_params is None: + self.model_params = gba_stoic_problem.model_params + else: + self.model_params = model_params + self._n_vars = len(gba_stoic_problem.var_initial) + self.report_freq = 0 + + self.nfev = 0 + self.njev = 0 + self.nit = 0 + self.constraint_dims = sum(gba_stoic_problem.n_constraints.values()) + + def objective(self, x): + jx = jnp.array(x) + self.nfev += 1 + return -np.float64(self.stoic_problem.jobj_mu_jit(jx, self.model_params)) + + def gradient(self, x): + jx = jnp.array(x) + self.njev += 1 + return -np.array(self.stoic_problem.jobj_mu_grad_jit(jx, self.model_params)) + + def constraints(self, x): + jx = jnp.array(x) + if self.stoic_problem.n_constraints['heq_enz_trans'] > 0: + eqs = np.hstack((self.stoic_problem.jeq_mass_balance_jit(jx, self.model_params), + self.stoic_problem.jeq_enz_trans_jit(jx, self.model_params), + self.stoic_problem.jeq_density_jit(jx, self.model_params))) + else: + eqs = np.hstack((self.stoic_problem.jeq_mass_balance_jit(jx, self.model_params), + self.stoic_problem.jeq_density_jit(jx, self.model_params))) + return eqs + + def jacobian(self, x): + jx = jnp.array(x) + if self.stoic_problem.n_constraints['heq_enz_trans'] > 0: + jacs = (np.vstack((self.stoic_problem.jeq_mass_balance_jac_jit(jx, self.model_params), + self.stoic_problem.jeq_enz_trans_jac_jit(jx, self.model_params), + self.stoic_problem.jeq_density_jac_jit(jx, self.model_params))) + ).flatten() + else: + jacs = (np.vstack((self.stoic_problem.jeq_mass_balance_jac_jit(jx, self.model_params), + self.stoic_problem.jeq_density_jac_jit(jx, self.model_params))) + ).flatten() + return jacs + + # def hessianstructure(self): + # return np.tril_indices(self._n_vars) + + def hessian(self, x, lagrange, obj_factor): + jx = jnp.array(x) + hess = obj_factor * -self.stoic_problem.jobj_mu_hess_jit(x, self.model_params) + if self.stoic_problem.n_constraints['heq_enz_trans'] > 0: + hesss = np.vstack((self.stoic_problem.jeq_mass_balance_hesss_jit(jx, self.model_params), + self.stoic_problem.jeq_enz_trans_hesss_jit(jx, self.model_params), + self.stoic_problem.jeq_density_hesss_jit(jx, self.model_params))) + else: + hesss = np.vstack((self.stoic_problem.jeq_mass_balance_hesss_jit(jx, self.model_params), + self.stoic_problem.jeq_density_hesss_jit(jx, self.model_params))) + hess += (lagrange.reshape((-1, 1, 1)) * hesss).sum(axis=0) + return hess[np.tril_indices(self._n_vars)] + + def set_report_freq(self, mod_val=0): + self.report_freq = mod_val + + # noinspection PyUnusedLocal + def intermediate(self, alg_mod, iter_count, obj_value, inf_pr, inf_du, mu, + d_norm, regularization_size, alpha_du, alpha_pr, ls_trials): + self.nit = iter_count + if self.report_freq > 0: + if iter_count % self.report_freq == 0: + print(f'[{iter_count:5d}] growth rate: {-obj_value:.4f} 1/h') diff --git a/xbanalysis/utils/opt_results.py b/xbanalysis/utils/opt_results.py index 5a03ca3..7468d9b 100644 --- a/xbanalysis/utils/opt_results.py +++ b/xbanalysis/utils/opt_results.py @@ -8,16 +8,19 @@ import numpy as np class OptResults: - def __init__(self, problem, n_points): + def __init__(self, problem, n_points, model_params=None): """Initialize OptResults. :param problem: GbaProblem from where to extract relevant parameters :type problem: GbaProblem + :param model_params: model parameters (used for jax based problems) + :type model_params: dict or None :param n_points: number of points to record :type n_points: integer """ self.problem = problem self._n_points = n_points + self.model_params = model_params self._idx = 0 self._n_m = len(self.problem.var_metab_sids) self._n_e = len(self.problem.var_enz_sids) @@ -48,15 +51,24 @@ class OptResults: self.ress['grs_per_h'][i] = -3600 * info['obj_val'] self.ress['cm_M'][i] = cx[self.problem.var_mask_metab] self.ress['ce_M'][i] = cx[self.problem.var_mask_enz] - rho = self.problem.model_params['mws'] * cx + rho = self.problem.var_mws * cx self.ress['rho_m'][i] = rho[self.problem.var_mask_metab] self.ress['rho_e'][i] = rho[self.problem.var_mask_enz] scale_factor = 3600 * 1e3 / self.problem.model_params['rho'] - self.ress['mr_fluxes'][i] = scale_factor * self.problem.mras(cx) / self.problem.mr_vols - self.ress['psm_fluxes'][i] = scale_factor / self.problem.pstmas(cx) / self.problem.ps_vols + if self.model_params is None: + self.ress['mr_fluxes'][i] = scale_factor * self.problem.mras(cx) / self.problem.mr_vols + self.ress['psm_fluxes'][i] = scale_factor / self.problem.pstmas(cx) / self.problem.ps_vols + + self.ress['alphas'][i] = self.problem.get_alphas(cx) + else: + self.ress['mr_fluxes'][i] = (scale_factor * self.problem.mras(cx, self.model_params) + / self.problem.mr_vols) + self.ress['psm_fluxes'][i] = (scale_factor / self.problem.pstmas(cx, self.model_params) + / self.problem.ps_vols) + + self.ress['alphas'][i] = self.problem.get_alphas(cx, self.model_params) - self.ress['alphas'][i] = self.problem.get_alphas(cx) self._idx += 1 def get(self): -- GitLab