From 3ed33387e734c7cdd498a0cc7f4a65c8d711b634 Mon Sep 17 00:00:00 2001
From: Peter Schubert <Peter.Schubert@hhu.de>
Date: Fri, 23 Sep 2022 16:24:48 +0200
Subject: [PATCH] gba_stoic_problem solved via ipopt using jax autograd and jit

---
 xbanalysis/problems/__init__.py               |   5 +-
 xbanalysis/problems/gba_problem.py            |  14 +-
 xbanalysis/problems/gba_sflux_problem.py      |   2 +-
 xbanalysis/problems/gba_stoic_problem.py      | 439 ++++++++++++++++++
 xbanalysis/problems/rba_problem.py            |   4 +-
 xbanalysis/solvers/__init__.py                |   7 +-
 ..._ipopt_problem.py => ipopt_gba_problem.py} |   8 +-
 ..._problem.py => ipopt_gba_sflux_problem.py} |  28 +-
 xbanalysis/solvers/ipopt_gba_stoic_problem.py |  95 ++++
 xbanalysis/utils/opt_results.py               |  22 +-
 10 files changed, 590 insertions(+), 34 deletions(-)
 create mode 100644 xbanalysis/problems/gba_stoic_problem.py
 rename xbanalysis/solvers/{gba_ipopt_problem.py => ipopt_gba_problem.py} (95%)
 rename xbanalysis/solvers/{gba_sflux_ipopt_problem.py => ipopt_gba_sflux_problem.py} (73%)
 create mode 100644 xbanalysis/solvers/ipopt_gba_stoic_problem.py

diff --git a/xbanalysis/problems/__init__.py b/xbanalysis/problems/__init__.py
index a43385f..7df3913 100644
--- a/xbanalysis/problems/__init__.py
+++ b/xbanalysis/problems/__init__.py
@@ -1,8 +1,9 @@
 """Subpackage with XBA model classes """
 
-from .gba_problem import GbaProblem
 from .rba_problem import RbaProblem
 from .fba_problem import FbaProblem
+from .gba_problem import GbaProblem
 from .gba_sflux_problem import GbaSfluxProblem
+from .gba_stoic_problem import GbaStoicProblem
 
-__all__ = ['GbaProblem', 'FbaProblem', 'RbaProblem', 'GbaSfluxProblem']
+__all__ = ['GbaProblem', 'FbaProblem', 'RbaProblem', 'GbaSfluxProblem', 'GbaStoicProblem']
diff --git a/xbanalysis/problems/gba_problem.py b/xbanalysis/problems/gba_problem.py
index 55d9e58..b9354bf 100644
--- a/xbanalysis/problems/gba_problem.py
+++ b/xbanalysis/problems/gba_problem.py
@@ -59,6 +59,7 @@ class GbaProblem:
         self.var_mask_metab = [True if sid in self.var_metab_sids else False for sid in self.var_sids]
         self.var_mask_enz = [True if sid in self.var_enz_sids else False for sid in self.var_sids]
         self.var_initial = np.array([xba_model.species[sid].initial_conc for sid in self.var_sids])
+        self.var_mws = np.array([xba_model.species[sid].mw for sid in self.var_sids])
         self.var_vols = np.array([xba_model.compartments[xba_model.species[sid].compartment].size
                                   for sid in self.var_sids])
 
@@ -110,7 +111,6 @@ class GbaProblem:
         self.model_params['ext_conc'] = np.array([xba_model.species[sid].initial_conc
                                                   for sid in self.const_metab_sids])
         self.model_params['np'] = np
-        self.model_params['mws'] = np.array([xba_model.species[sid].mw for sid in self.var_sids])
 
         # create functions
         self.fmras, self.fmras_jac, self.fmras_hesss = self.get_reactions(self.mr_rids, inv=False)
@@ -389,7 +389,9 @@ class GbaProblem:
 
         considering stoichiometric coefficiens unequal unity for enzyme turnover
         stoichiometric sub-matrices do not require volume scaling (as Vol_enz gets factored out)
-        With dilution (negative stoichiometric coefficients in enz_x_mrs), growth rate reduces
+        With degradation (negative stoichiometric coefficients in enz_x_mrs), the growth rate is reduced
+        - first, the enzyme amount degraded per unit time is calculated
+
 
         Equations:
             gr(x) = gr0(x) * (1 + enz_to_mras(x).dot(pstmas(x)))
@@ -714,9 +716,9 @@ class GbaProblem:
         """
         ce = x[self.var_mask_enz]
         if self.model_params.get('macro', True) is True:
-            density = self.model_params['mws'][self.var_mask_enz].dot(ce)
+            density = self.var_mws[self.var_mask_enz].dot(ce)
         else:
-            density = self.model_params['mws'].dot(x)
+            density = self.var_mws.dot(x)
         return np.array([self.model_params['rho'] - density]) * self.scale_density
 
     @cache_last
@@ -734,9 +736,9 @@ class GbaProblem:
         """
         if self.model_params.get('macro', True) is True:
             density_grad = np.zeros(len(x))
-            density_grad[self.var_mask_enz] = -self.model_params['mws'][self.var_mask_enz]
+            density_grad[self.var_mask_enz] = -self.var_mws[self.var_mask_enz]
         else:
-            density_grad = -self.model_params['mws']
+            density_grad = -self.var_mws
         return density_grad * self.scale_density
 
     # noinspection PyUnusedLocal
diff --git a/xbanalysis/problems/gba_sflux_problem.py b/xbanalysis/problems/gba_sflux_problem.py
index 26a8b09..672fd34 100644
--- a/xbanalysis/problems/gba_sflux_problem.py
+++ b/xbanalysis/problems/gba_sflux_problem.py
@@ -54,7 +54,7 @@ class GbaSfluxProblem:
         rid_ps_ribo = xba_model.species[sid_ribo].ps_rid
         self.rids = self.mr_rids + [rid_ps_ribo]
         self.rid2idx = {mrid: idx for idx, mrid in enumerate(self.rids)}
-        self.enz_sids = np.array([self.xba_model.reactions[rid].enzyme for rid in self.rids])
+        self.enz_sids = np.array([xba_model.reactions[rid].enzyme for rid in self.rids])
 
         self.mw_sids = np.array([xba_model.species[sid].mw for sid in self.sids])
         self.mw_enz = np.array([xba_model.species[enz_sid].mw for enz_sid in self.enz_sids])
diff --git a/xbanalysis/problems/gba_stoic_problem.py b/xbanalysis/problems/gba_stoic_problem.py
new file mode 100644
index 0000000..87f55fa
--- /dev/null
+++ b/xbanalysis/problems/gba_stoic_problem.py
@@ -0,0 +1,439 @@
+"""Implementation of GBA stoic Problem class.
+
+based on gba_problem
+- extended Growth Balance Analysis (GBA) problem from Deniz Sezer
+- the basic problem uses internal metabolite and enzyme concentrations
+  as free (optimization) variables;
+  it optimizes growth rate under constraints of mass balance and
+  total density (cellular density or macromolecular density)
+- the extended formulation additionally accounts for
+  - protein degradation
+  - protein interconversion
+  - spontaneous reactions
+
+- a related formulation is based on Hugo Dourado's scaled fluxes
+  see gba_sflux_problem.py
+
+- this implementation uses the JAX package to
+  - determine first and second derivatives of all functions automatically
+  - just-in-time compile functions for speed-up
+  - support accelerators (e.g. GPUs) - yet to be tested
+
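+- minimal usage sketch (assuming xba_model is an already configured XbaModel
+  instance; the matching Ipopt wrapper is IpoptGbaStoicProblem in
+  xbanalysis.solvers):
+
+      problem = GbaStoicProblem(xba_model)
+      problem.compile(problem.var_initial, problem.model_params)   # warm up jit
+      gr = problem.jobj_mu_jit(problem.var_initial, problem.model_params)
+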
+Peter Schubert, HHU Duesseldorf, September 2022
+"""
+
+import numpy as np
+import re
+import types
+import scipy
+import jax
+import jax.numpy as jnp
+
+jax.config.update("jax_enable_x64", True)
+
+# TODO: variable transformation x = exp(y) and solve for unbounded y. start with vector y = log(x)
+
+
+class GbaStoicProblem:
+
+    cache = True
+    last_values = {}
+
+    def __init__(self, xba_model):
+        """Initialize GbaStoichProblem from a xba_model.
+
+        :param xba_model: xba model with all required GBA parameters
+        :type xba_model: XbaModel
+        """
+        self.xba_model = xba_model
+
+        # scale the density constraint to improve convergence (to be checked)
+        self.scale_density = 1e-4
+
+        # metabolites and enzymes in the model
+        metab_sids = [s.id for s in xba_model.species.values()
+                      if s.sboterm.is_in_branch('SBO:0000247')]
+
+        # optimization variables
+        self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False]
+        self.var_enz_sids = [s.id for s in xba_model.species.values()
+                             if s.sboterm.is_in_branch('SBO:0000246')]
+        self.var_sids = self.var_metab_sids + self.var_enz_sids
+        self.n_vars = len(self.var_sids)
+        self.var_sid2idx = {sid: idx for idx, sid in enumerate(self.var_sids)}
+        self.var_mask_metab = [sid in self.var_metab_sids for sid in self.var_sids]
+        self.var_mask_enz = [sid in self.var_enz_sids for sid in self.var_sids]
+        self.var_initial = np.array([xba_model.species[sid].initial_conc for sid in self.var_sids])
+        self.var_mws = np.array([xba_model.species[sid].mw for sid in self.var_sids])
+        self.var_vols = np.array([xba_model.compartments[xba_model.species[sid].compartment].size
+                                  for sid in self.var_sids])
+
+        # constant concentrations
+        self.const_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is True]
+        self.const_sid2idx = {sid: idx for idx, sid in enumerate(self.const_metab_sids)}
+
+        # protein synthesis reactions get ordered according to the order of their enzyme products
+        #  to facilitate matrix multiplication without further remapping
+        # - here we expect a one-to-one relationship between enzyme and corresponding protein synthesis reaction
+        self.mr_rids = [r.id for r in xba_model.reactions.values()
+                        if r.sboterm.is_in_branch('SBO:0000167')]
+        self.ps_rids = [xba_model.species[sid].ps_rid for sid in self.var_enz_sids]
+        self.mr_vols = np.array([xba_model.compartments[xba_model.reactions[rid].compartment].size
+                                 for rid in self.mr_rids])
+        self.ps_vols = np.array([xba_model.compartments[xba_model.reactions[rid].compartment].size
+                                 for rid in self.ps_rids])
+
+        # calculate stoichiometric sub-matrices and scale them according to volume
+        # TODO: sparse matrix storage and calculations
+        species_x_mrs = xba_model.get_stoic_matrix(self.var_sids, self.mr_rids)
+        self.dof = species_x_mrs.shape[1] - np.linalg.matrix_rank(species_x_mrs)
+        metab_x_mrs = xba_model.get_stoic_matrix(self.var_metab_sids, self.mr_rids)
+        self.metab_x_psrs = xba_model.get_stoic_matrix(self.var_metab_sids, self.ps_rids)
+        self.enz_x_mrs = xba_model.get_stoic_matrix(self.var_enz_sids, self.mr_rids)
+        self.enz_vols = self.var_vols[self.var_mask_enz]
+        metab_vols = self.var_vols[self.var_mask_metab]
+
+        self.metab_x_mras = metab_x_mrs / metab_vols.reshape((-1, 1))
+        self.metab_x_psras = self.metab_x_psrs / metab_vols.reshape((-1, 1))
+        self.enz_x_mras = self.enz_x_mrs / self.enz_vols.reshape((-1, 1))
+
+        # identify reaction ids for enzyme degradation and the corresponding submatrix
+        enz_degrad_idxs = self.get_export_idxs(self.enz_x_mrs)
+        enz_degrad_rids = np.array(self.mr_rids)[enz_degrad_idxs]
+        self.enz_x_degrads = self.enz_x_mrs[:, enz_degrad_idxs]
+
+        # identify reaction ids for enzyme transitions and the corresponding submatrix
+        # for this submatrix, perform an LU decomposition and remove empty rows from the U matrix
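+        #  (removing the zero rows of U is intended to keep a reduced set of
+        #   linearly independent transition constraints)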
+        enz_trans_idxs = self.get_trans_idxs(self.enz_x_mras)
+        enz_trans_rids = np.array(self.mr_rids)[enz_trans_idxs]
+        u = scipy.linalg.lu(self.enz_x_mras[:, enz_trans_idxs])[2]
+        self.subenz_x_trans = u[abs(u).sum(axis=1) > 0]
+
+        # model parameters
+        self.model_params = {pid: p.value for pid, p in xba_model.parameters.items()}
+        self.model_params.update({cid: c.size for cid, c in xba_model.compartments.items()})
+        self.model_params['macro'] = self.model_params.get('macro', 0.0) > 0.0
+        self.model_params['ext_conc'] = np.array([xba_model.species[sid].initial_conc
+                                                  for sid in self.const_metab_sids])
+
+        # move some numpy structures to jax.numpy
+        # TODO: possibly code relevant numpy objects in jax.numpy to allow for acceleration
+        self.jvar_mask_enz = jnp.array(self.var_mask_enz)
+        self.jenz_x_mrs = jnp.array(self.enz_x_mrs)
+        self.jvar_mask_metab = jnp.array(self.var_mask_metab)
+        self.jenz_x_degrads = jnp.array(self.enz_x_degrads)
+        self.jenz_vols = jnp.array(self.enz_vols)
+        self.jsubenz_x_trans = jnp.array(self.subenz_x_trans)
+        self.jvar_mws = jnp.array(self.var_mws)
+        self.jmetab_x_mras = jnp.array(self.metab_x_mras)
+        self.jmetab_x_psrs = jnp.array(self.metab_x_psrs)
+        self.jenz_x_mras = jnp.array(self.enz_x_mras)
+
+        # retrieve kinetic functions and compile them
+        self.mras = self.get_reactions(self.mr_rids, inv=False)
+        self.pstmas = self.get_reactions(self.ps_rids, inv=True)
+        self.jmras_jit = jax.jit(self.mras)
+        self.jpstmas_jit = jax.jit(self.pstmas)
+        self.jdegradas_jit = jax.jit(self.get_reactions(enz_degrad_rids, inv=False))
+        self.jenz_tras_jit = jax.jit(self.get_reactions(enz_trans_rids, inv=False))
+
+        # create functions used for optimization
+        self.jobj_mu_jit = jax.jit(self.get_growth_rate)
+        self.jobj_mu_grad_jit = jax.jit(jax.grad(self.jobj_mu_jit))
+        self.jobj_mu_hess_jit = jax.jit(jax.jacfwd(self.jobj_mu_grad_jit))
+
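+        # constraint functions: value, dense Jacobian (jax.jacrev) and stacked
+        # per-constraint Hessians (jax.jacfwd of the Jacobian), consumed by the
+        # Ipopt wrapper in ipopt_gba_stoic_problem.py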
+        self.jeq_density_jit = jax.jit(self.heq_density)
+        self.jeq_density_jac_jit = jax.jit(jax.jacrev(self.jeq_density_jit))
+        self.jeq_density_hesss_jit = jax.jit(jax.jacfwd(self.jeq_density_jac_jit))
+        self.jeq_density_macro_jit = jax.jit(self.heq_density_macro)
+        self.jeq_density_macro_jac_jit = jax.jit(jax.jacrev(self.jeq_density_macro_jit))
+        self.jeq_density_macro_hesss_jit = jax.jit(jax.jacfwd(self.jeq_density_macro_jac_jit))
+
+        self.jeq_mass_balance_jit = jax.jit(self.heq_mass_balance)
+        self.jeq_mass_balance_jac_jit = jax.jit(jax.jacrev(self.jeq_mass_balance_jit))
+        self.jeq_mass_balance_hesss_jit = jax.jit(jax.jacfwd(self.jeq_mass_balance_jac_jit))
+
+        self.jeq_enz_trans_jit = jax.jit(self.heq_enz_trans)
+        self.jeq_enz_trans_jac_jit = jax.jit(jax.jacrev(self.jeq_enz_trans_jit))
+        self.jeq_enz_trans_hesss_jit = jax.jit(jax.jacfwd(self.jeq_enz_trans_jac_jit))
+
+        self.n_constraints = {'heq_mass_balance': len(self.var_metab_sids),
+                              'heq_enz_trans': self.subenz_x_trans.shape[0],
+                              'heq_density': 1}
+
+    def compile(self, x, model_params):
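+        """Trigger jit compilation by evaluating all jit-wrapped functions once at x.
+
+        :param x: optimization variables (representative point)
+        :type x: numpy.ndarray
+        :param model_params: model parameters
+        :type model_params: dict
+        """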
+        jx = jnp.array(x)
+
+        self.jmras_jit(jx, model_params)
+        self.jpstmas_jit(jx, model_params)
+        self.jdegradas_jit(jx, model_params)
+        self.jenz_tras_jit(jx, model_params)
+
+        self.jobj_mu_jit(jx, model_params)
+        self.jobj_mu_grad_jit(jx, model_params)
+        self.jobj_mu_hess_jit(jx, model_params)
+
+        self.jeq_density_jit(jx, model_params)
+        self.jeq_density_jac_jit(jx, model_params)
+        self.jeq_density_hesss_jit(jx, model_params)
+        self.jeq_density_macro_jit(jx, model_params)
+        self.jeq_density_macro_jac_jit(jx, model_params)
+        self.jeq_density_macro_hesss_jit(jx, model_params)
+
+        self.jeq_mass_balance_jit(jx, model_params)
+        self.jeq_mass_balance_jac_jit(jx, model_params)
+        self.jeq_mass_balance_hesss_jit(jx, model_params)
+
+        self.jeq_enz_trans_jit(jx, model_params)
+        self.jeq_enz_trans_jac_jit(jx, model_params)
+        self.jeq_enz_trans_hesss_jit(jx, model_params)
+
+    def update_model_params(self, key, val):
+        self.model_params[key] = val
+
+    def _to_vectors(self, rs_str):
+        """replace variables and constant metabolites by vector elements.
+
+        Variables converted to x[], external metabolites to ext_conc[].
+        access to model parameters via model_params dict
+
+        :param rs_str: reaction strings (extracted from the xba model).
+        :type rs_str: list of strings
+        :returns: rs_v_str
+        :rtype: list of reaction stings with parameters replaced by vector elements
+        """
+        rs_v_str = []
+        for r_v_str in rs_str:
+            for idx, sid in enumerate(self.var_sids):
+                r_v_str = re.sub(r'\b' + sid + r'\b', f'x[{idx}]', r_v_str)
+            for idx, sid in enumerate(self.const_metab_sids):
+                r_v_str = re.sub(r'\b' + sid + r'\b', f'ext_conc[{idx}]', r_v_str)
+            for param_key in self.model_params:
+                r_v_str = re.sub(r'\b' + param_key + r'\b', f'model_params["{param_key}"]', r_v_str)
+            rs_v_str.append(r_v_str)
+        return rs_v_str
+
+    def get_reactions(self, rids, inv=False):
+        """create function objects for reactions rates/times within jax.numpy namespace
+
+        extracted from kinetic laws for SBML model
+
+        Note: reactions in the SBML model are in amounts per time (not concentration per time)
+        I.e. for mass balances equations, we have to devide by volumes where metabolite is located
+
+        :param rids: reaction ids for reactions in amount per time to be processed.
+        :type rids: list of strings
+        :param inv: flag indicating if inverse of reaction required (i.e. reaction times)
+        :type inv: boolean (default: False)
+        :returns: jras
+        :rtype: function object, returning an 1D array of function results
+        """
+        if inv is False:
+            kinetics_str = [self.xba_model.reactions[rid].expanded_kl for rid in rids]
+        else:
+            kinetics_str = [f'1.0/({self.xba_model.reactions[rid].expanded_kl})' for rid in rids]
+        v_str = self._to_vectors(kinetics_str)
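+        # dynamically create "def reactions(x, model_params)" returning all kinetic
+        # laws as a single array; 'np' inside the generated code is bound to
+        # jax.numpy so the function can be traced by jax.jit / jax.grad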
+        func_code = compile('def reactions(x, model_params): '
+                            'return np.array([' + ', '.join(v_str) + '])',
+                            '<string>', 'exec')
+        jras = types.FunctionType(func_code.co_consts[0], {'np': jnp})
+        return jras
+
+    @staticmethod
+    def get_trans_idxs(smat):
+        """Identify metabolic reactions in S matrix with species transitions.
+
+        I.e. columns containing both positive and negative entries
+
+        :param smat: stoichiometric (sub-)matrix
+        :type smat: 2D ndarray
+        :return: boolean vector identifying relevant columns
+        :rtype: 1D ndarray with dtype bool
+        """
+        mask_trans_idxs = np.zeros(smat.shape[1])
+        for col_idx in range(smat.shape[1]):
+            if ((np.count_nonzero(smat[:, col_idx] < 0) > 0) and
+                    np.count_nonzero(smat[:, col_idx] > 0) > 0):
+                mask_trans_idxs[col_idx] = 1.0
+        return mask_trans_idxs > 0
+
+    @staticmethod
+    def get_export_idxs(smat):
+        """Identify metabolic reactions in S matrix with consumption only.
+
+        I.e. columns containing only negative entries
+
+        :param smat: stoichiometric (sub-)matrix
+        :type smat: 2D ndarray
+        :return: boolean vector identifying relevant columns
+        :rtype: 1D ndarray with dtype bool
+        """
+        mask_remove_idxs = np.zeros(smat.shape[1])
+        for col_idx in range(smat.shape[1]):
+            if ((np.count_nonzero(smat[:, col_idx] < 0) > 0) and
+                    np.count_nonzero(smat[:, col_idx] > 0) == 0):
+                mask_remove_idxs[col_idx] = 1.0
+        return mask_remove_idxs > 0
+
+    # @jax.jit
+    def get_growth_rate_0(self, x, model_params):
+        """Calculate growth rate at x, assuming no protein degradation
+
+        Equations:
+            gr0(x) = 1/tau(x)
+            tau(x) = pstms(x).dot(ce)
+            pstms(x) = enz_vols * pstmas(x)
+
+        :param x: optimization variables
+        :type x: numpy.ndarray
+        :param model_params: model parameters that can be changed
+        :type model_params: dict
+        :returns: growth rate at x (without protein degradation)
+        :rtype: float
+        """
+        jx = jnp.array(x)
+        ce = jx[self.jvar_mask_enz]
+        pstms = self.jenz_vols * self.jpstmas_jit(jx, model_params)
+        gr0 = jnp.divide(1.0, jnp.dot(pstms, ce))
+        return gr0
+
+    # @jax.jit
+    def get_growth_rate(self, x, model_params):
+        """Calculate growth rate at x, considering protein degradation
+
+        considering stoichiometric coefficients unequal to unity for enzyme turnover
+        stoichiometric sub-matrices do not require volume scaling (as Vol_enz gets factored out)
+        With degradation (negative stoichiometric coefficients in enz_x_mrs), the growth rate is reduced
+        - first, the enzyme amount degraded per unit time is calculated
+
+        Equations:
+            gr(x) = gr0(x) * (1 + enz_to_mras(x).dot(pstmas(x)))
+            enz_to_mras(x) = enz_x_mrs.dot(mras(x))
+
+        :param x: optimization variables
+        :type x: numpy.ndarray
+        :param model_params: model parameters that can be changed
+        :type model_params: dict
+        :returns: growth rate at x (considering protein degradation)
+        :rtype: float
+        """
+        jx = jnp.array(x)
+        gr0 = self.get_growth_rate_0(jx, model_params)
+        pstmas = self.jpstmas_jit(jx, model_params)
+        enz_to_mras = jnp.dot(self.jenz_x_degrads, self.jdegradas_jit(jx, model_params))
+        gr = gr0 * (1.0 + jnp.dot(enz_to_mras, pstmas))
+        return gr
+
+    def get_alphas(self, x, model_params):
+        """Calculate ribosomal allocations at x, considering protein degradation
+
+        alpha_e = gr * ce_e * pstm_e + p_degr_e * pstm_e
+        alphas = gr * ce * pstms(x) + p_degrs(x) * pstmas(x)   (element-wise)
+
+        :param x: optimization variables
+        :type x: numpy.ndarray
+        :param model_params: model parameters that can be changed
+        :type model_params: dict
+        :returns: ribosomal allocations at x (considering protein degradation)
+        :rtype: numpy.ndarray (1D)
+        """
+        jx = jnp.array(x)
+        ce = jx[self.jvar_mask_enz]
+        gr = self.get_growth_rate(jx, model_params)
+        pstms = self.jenz_vols * self.jpstmas_jit(jx, model_params)
+
+        alphas0 = gr * ce * pstms
+        pstmas = self.jpstmas_jit(jx, model_params)
+        enz_to_mras = jnp.dot(self.jenz_x_degrads, self.jdegradas_jit(jx, model_params))
+        alphas = alphas0 - enz_to_mras * pstmas
+        return alphas
+
+    def heq_mass_balance(self, x, model_params):
+        """Calculate mass balance equations at x for metabolites
+
+        including protein degradation, which consumes metabolites
+        Note: protein degradation is also part of the metabolic reactions (mras), which return
+              metabolites.
+
+        metab_mb = dcm/dt = metabolic_balance + growth_dilution + synthesis_enzyme_balance
+          metabolic_balance = metab_x_mras.dot(mras(x))
+          growth_dilution = λ * (metab_x_psrs.dot(ce) - cm)
+            ce_scaled = enz_vols * ce   # convert concentration to amounts, considering stoichiometry
+          synthesis_enzyme_balance = - metab_x_psras.dot(enz_vols * enzyme_balance)
+            enzyme_balance = enz_x_mras.dot(mras(x))
+
+        :param x: optimization variables
+        :type x: numpy.ndarray
+        :param model_params: model parameters that can be changed
+        :type model_params: dict
+        :returns: mass balance calculated at x
+        :rtype: numpy.ndarray (1D) with shape (len(var_metab_sids),)
+        """
+        jx = jnp.array(x)
+        gr = self.get_growth_rate(jx, model_params)
+        cm = jx[self.jvar_mask_metab]
+        ce = jx[self.jvar_mask_enz]
+
+        metabolic_mb = jnp.dot(self.jmetab_x_mras, self.jmras_jit(jx, model_params))
+        dilution = jnp.dot(self.jmetab_x_psrs, ce) - cm
+        metab_mb = metabolic_mb + gr * dilution
+        enz_mb = jnp.dot(self.jenz_x_mras, self.jmras_jit(jx, model_params))
+        metab_mb -= jnp.dot(self.jmetab_x_psrs, enz_mb)
+        return metab_mb
+
+    def heq_enz_trans(self, x, model_params):
+        """Calculate enzyme transition equations at x
+
+        :param x: optimization variables
+        :type x: numpy.ndarray
+        :param model_params: model parameters that can be changed
+        :type model_params: dict
+        :returns: enzyme transition balances at x
+        :rtype: numpy.ndarray (1D) with shape (subenz_x_trans.shape[0],)
+        """
+        jx = jnp.array(x)
+        return jnp.dot(self.jsubenz_x_trans, self.jenz_tras_jit(jx, model_params))
+
+    def heq_density(self, x, model_params):
+        """Calculate dry mass density (g/l) constraint at x
+
+        Could be extended to several density constraints (not yet implemented)
+
+        alternatively, we could check the macromolecular (protein) density using heq_density_macro
+
+        using ipopt, the constraint bounds could be set to rho (only add up masses)
+        density = sum(species_concentration * molecular_weight)
+
+        :param x: optimization variables
+        :type x: numpy.ndarray
+        :param model_params: model parameters that can be changed
+        :type model_params: dict
+        :returns: density constraint value at x
+        :rtype: numpy.ndarray (1D) with shape (1,) for now
+        """
+        jx = jnp.array(x)
+        density = jnp.dot(self.jvar_mws, jx)
+        return jnp.array([model_params['rho'] - density]) * self.scale_density
+
+    def heq_density_macro(self, x, model_params):
+        """Calculate macromolecular density (g/l) constraint at x
+
+        Could be extended to several density constraints (not yet implemented)
+
+        macromolecular density is total protein mass density
+
+        using ipopt, the constraint bounds could be set to rho (only add up masses)
+        density = sum(species_concentration * molecular_weight)
+
+        :param x: optimization variables
+        :type x: numpy.ndarray
+        :param model_params: model parameters that can be changed
+        :type model_params: dict
+        :returns: macromolecular density constraint value at x
+        :rtype: numpy.ndarray (1D) with shape (1,) for now
+        """
+        jx = jnp.array(x)
+        ce = jx[self.jvar_mask_enz]
+        density = jnp.dot(self.jvar_mws[self.jvar_mask_enz], ce)
+        return jnp.array([model_params['rho'] - density]) * self.scale_density
diff --git a/xbanalysis/problems/rba_problem.py b/xbanalysis/problems/rba_problem.py
index dc0fafe..7575115 100644
--- a/xbanalysis/problems/rba_problem.py
+++ b/xbanalysis/problems/rba_problem.py
@@ -18,6 +18,7 @@ from scipy.sparse import coo_array, hstack, vstack
 
 from xbanalysis.solvers.glpk_linear_problem import GlpkLinearProblem
 
+# TODO: implement protein degradation
 
 class RbaProblem:
 
@@ -26,7 +27,6 @@ class RbaProblem:
 
         :param xba_model:
         """
-
         self.xba_model = xba_model
 
         # metabolic reactions and enzyme transitions (no FBA reactions, no protein synthesis, no degradation)
@@ -36,7 +36,7 @@ class RbaProblem:
         self.rid2enz = {rid: xba_model.reactions[rid].enzyme for rid in self.rids}
         self.enz_sids = np.array([enz_sid for enz_sid in self.rid2enz.values() if enz_sid is not None])
         self.enz_sids_rev = np.array([enz_sid for rid, enz_sid in self.rid2enz.items()
-                                      if enz_sid is not None and  xba_model.reactions[rid].reversible])
+                                      if enz_sid is not None and xba_model.reactions[rid].reversible])
         self.enz2idx = {enz: idx for idx, enz in enumerate(self.enz_sids)}
         self.processes = {s.rba_process_machine: sid for sid, s in xba_model.species.items()
                           if hasattr(s, 'rba_process_machine')}
diff --git a/xbanalysis/solvers/__init__.py b/xbanalysis/solvers/__init__.py
index 4e4ee2e..6d25e64 100644
--- a/xbanalysis/solvers/__init__.py
+++ b/xbanalysis/solvers/__init__.py
@@ -1,7 +1,8 @@
 """Subpackage with XBA model classes """
 
 from .glpk_linear_problem import GlpkLinearProblem
-from .gba_ipopt_problem import GbaIpoptProblem
-from .gba_sflux_ipopt_problem import GbaSfluxIpoptProblem
+from .ipopt_gba_problem import IpoptGbaProblem
+from .ipopt_gba_sflux_problem import IpoptGbaSfluxProblem
+from .ipopt_gba_stoic_problem import IpoptGbaStoicProblem
 
-__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem', 'GbaSfluxIpoptProblem']
+__all__ = ['GlpkLinearProblem', 'IpoptGbaProblem', 'IpoptGbaSfluxProblem', 'IpoptGbaStoicProblem']
diff --git a/xbanalysis/solvers/gba_ipopt_problem.py b/xbanalysis/solvers/ipopt_gba_problem.py
similarity index 95%
rename from xbanalysis/solvers/gba_ipopt_problem.py
rename to xbanalysis/solvers/ipopt_gba_problem.py
index 958431c..fa129b5 100644
--- a/xbanalysis/solvers/gba_ipopt_problem.py
+++ b/xbanalysis/solvers/ipopt_gba_problem.py
@@ -1,4 +1,4 @@
-"""Implementation of GbaIpoptProblem class.
+"""Implementation of IpoptGbaProblem class.
 
 Peter Schubert, HHU Duesseldorf, June 2022
 """
@@ -7,7 +7,7 @@ import numpy as np
 
 
 # based on cyipopt examples
-class GbaIpoptProblem:
+class IpoptGbaProblem:
 
     def __init__(self, gba_problem):
         """Initialize GbaIpoptProblem.
@@ -56,8 +56,8 @@ class GbaIpoptProblem:
                     ).flatten()
         return jacs
 
-    def hessianstructure(self):
-        return np.tril_indices(self._n_vars)
+    # def hessianstructure(self):
+    #   return np.tril_indices(self._n_vars)
 
     def hessian(self, x, lagrange, obj_factor):
         hess = obj_factor * -self.gba_problem.get_growth_rate_hess(x)
diff --git a/xbanalysis/solvers/gba_sflux_ipopt_problem.py b/xbanalysis/solvers/ipopt_gba_sflux_problem.py
similarity index 73%
rename from xbanalysis/solvers/gba_sflux_ipopt_problem.py
rename to xbanalysis/solvers/ipopt_gba_sflux_problem.py
index 49718a3..e4a1f5b 100644
--- a/xbanalysis/solvers/gba_sflux_ipopt_problem.py
+++ b/xbanalysis/solvers/ipopt_gba_sflux_problem.py
@@ -1,4 +1,4 @@
-"""Implementation of GbaSfluxIpoptProblem class.
+"""Implementation of IpoptGbaSfluxProblem class.
 
 For solving GBA problem using scaled fluxes as per Hugo Dourado
 
@@ -6,10 +6,11 @@ Peter Schubert, HHU Duesseldorf, September 2022
 """
 
 import numpy as np
+import jax.numpy as jnp
 
 
 # based on cyipopt examples
-class GbaSfluxIpoptProblem:
+class IpoptGbaSfluxProblem:
 
     def __init__(self, gba_sflux_problem, model_params=None):
         """Initialize sFluxIpoptProblem.
@@ -33,23 +34,27 @@ class GbaSfluxIpoptProblem:
         self.constraint_dims = sum(gba_sflux_problem.n_constraints.values())
 
     def objective(self, x):
+        jx = jnp.array(x)
         self.nfev += 1
-        return -np.float64(self.sflux_problem.obj_mu(x, self.model_params))
+        return -np.float64(self.sflux_problem.obj_mu(jx, self.model_params))
 
     def gradient(self, x):
+        jx = jnp.array(x)
         self.njev += 1
-        return -np.array(self.sflux_problem.obj_mu_grad(x, self.model_params))
+        return -np.array(self.sflux_problem.obj_mu_grad(jx, self.model_params))
 
     def constraints(self, x):
+        jx = jnp.array(x)
         # TODO: improve code
-        eqs = np.hstack((self.sflux_problem.constr_density(x, self.model_params),
-                         self.sflux_problem.constrs_conc(x, self.model_params)))
+        eqs = np.hstack((self.sflux_problem.constr_density(jx, self.model_params),
+                         self.sflux_problem.constrs_conc(jx, self.model_params)))
         return eqs
 
     def jacobian(self, x):
         # TODO: improve code
-        jacs = (np.vstack((self.sflux_problem.constr_density_grad(x, self.model_params),
-                           self.sflux_problem.constrs_conc_jac(x, self.model_params)))
+        jx = jnp.array(x)
+        jacs = (np.vstack((self.sflux_problem.constr_density_grad(jx, self.model_params),
+                           self.sflux_problem.constrs_conc_jac(jx, self.model_params)))
                 ).flatten()
         return jacs
 
@@ -57,9 +62,10 @@ class GbaSfluxIpoptProblem:
     #    return np.tril_indices(self._n_vars)
 
     def hessian(self, x, lagrange, obj_factor):
-        hess = obj_factor * -self.sflux_problem.obj_mu_hess(x, self.model_params)
-        hesss = np.vstack(([self.sflux_problem.constr_density_hess(x, self.model_params)],
-                           self.sflux_problem.constrs_conc_hesss(x, self.model_params)))
+        jx = jnp.array(x)
+        hess = obj_factor * -self.sflux_problem.obj_mu_hess(jx, self.model_params)
+        hesss = np.vstack(([self.sflux_problem.constr_density_hess(jx, self.model_params)],
+                           self.sflux_problem.constrs_conc_hesss(jx, self.model_params)))
         hess += (lagrange.reshape((-1, 1, 1)) * hesss).sum(axis=0)
         return hess[np.tril_indices(self._n_vars)]
 
diff --git a/xbanalysis/solvers/ipopt_gba_stoic_problem.py b/xbanalysis/solvers/ipopt_gba_stoic_problem.py
new file mode 100644
index 0000000..1303d03
--- /dev/null
+++ b/xbanalysis/solvers/ipopt_gba_stoic_problem.py
@@ -0,0 +1,95 @@
+"""Implementation of IpoptGbaStoicProblem class.
+
+For solving the GBA problem in its stoichiometric formulation via Ipopt
+
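+A rough usage sketch (assumptions: cyipopt is installed, xba_model is a configured
+XbaModel instance, the variable bounds below are placeholders to be chosen per model;
+all constraints are equalities, hence cl = cu = 0):
+
+    import numpy as np
+    import cyipopt
+
+    stoic_problem = GbaStoicProblem(xba_model)
+    ipopt_problem = IpoptGbaStoicProblem(stoic_problem)
+    m = ipopt_problem.constraint_dims
+    lb = np.zeros(stoic_problem.n_vars)           # placeholder lower bounds
+    ub = np.full(stoic_problem.n_vars, 100.0)     # placeholder upper bounds
+    nlp = cyipopt.Problem(n=stoic_problem.n_vars, m=m, problem_obj=ipopt_problem,
+                          lb=lb, ub=ub, cl=np.zeros(m), cu=np.zeros(m))
+    x_opt, info = nlp.solve(stoic_problem.var_initial)
+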
+Peter Schubert, HHU Duesseldorf, September 2022
+"""
+
+import numpy as np
+import jax.numpy as jnp
+
+
+# based on cyipopt examples
+class IpoptGbaStoicProblem:
+
+    def __init__(self, gba_stoic_problem, model_params=None):
+        """Initialize sFluxIpoptProblem.
+
+        :param gba_stoic_problem: GbaStoicProblem from where to extract relevant parameters
+        :type gba_stoic_problem: GbaStoicProblem
+        :param model_params: model parameters (optional, default None, use params from GbaStoicProblem
+        :type model_params: None or dict
+        """
+        self.stoic_problem = gba_stoic_problem
+        if model_params is None:
+            self.model_params = gba_stoic_problem.model_params
+        else:
+            self.model_params = model_params
+        self._n_vars = len(gba_stoic_problem.var_initial)
+        self.report_freq = 0
+
+        self.nfev = 0
+        self.njev = 0
+        self.nit = 0
+        self.constraint_dims = sum(gba_stoic_problem.n_constraints.values())
+
+    def objective(self, x):
+        jx = jnp.array(x)
+        self.nfev += 1
+        return -np.float64(self.stoic_problem.jobj_mu_jit(jx, self.model_params))
+
+    def gradient(self, x):
+        jx = jnp.array(x)
+        self.njev += 1
+        return -np.array(self.stoic_problem.jobj_mu_grad_jit(jx, self.model_params))
+
+    def constraints(self, x):
+        jx = jnp.array(x)
+        if self.stoic_problem.n_constraints['heq_enz_trans'] > 0:
+            eqs = np.hstack((self.stoic_problem.jeq_mass_balance_jit(jx, self.model_params),
+                             self.stoic_problem.jeq_enz_trans_jit(jx, self.model_params),
+                             self.stoic_problem.jeq_density_jit(jx, self.model_params)))
+        else:
+            eqs = np.hstack((self.stoic_problem.jeq_mass_balance_jit(jx, self.model_params),
+                             self.stoic_problem.jeq_density_jit(jx, self.model_params)))
+        return eqs
+
+    def jacobian(self, x):
+        jx = jnp.array(x)
+        if self.stoic_problem.n_constraints['heq_enz_trans'] > 0:
+            jacs = (np.vstack((self.stoic_problem.jeq_mass_balance_jac_jit(jx, self.model_params),
+                               self.stoic_problem.jeq_enz_trans_jac_jit(jx, self.model_params),
+                               self.stoic_problem.jeq_density_jac_jit(jx, self.model_params)))
+                    ).flatten()
+        else:
+            jacs = (np.vstack((self.stoic_problem.jeq_mass_balance_jac_jit(jx, self.model_params),
+                               self.stoic_problem.jeq_density_jac_jit(jx, self.model_params)))
+                    ).flatten()
+        return jacs
+
+    # def hessianstructure(self):
+    #    return np.tril_indices(self._n_vars)
+
+    def hessian(self, x, lagrange, obj_factor):
+        jx = jnp.array(x)
+        hess = obj_factor * -self.stoic_problem.jobj_mu_hess_jit(jx, self.model_params)
+        if self.stoic_problem.n_constraints['heq_enz_trans'] > 0:
+            hesss = np.vstack((self.stoic_problem.jeq_mass_balance_hesss_jit(jx, self.model_params),
+                               self.stoic_problem.jeq_enz_trans_hesss_jit(jx, self.model_params),
+                               self.stoic_problem.jeq_density_hesss_jit(jx, self.model_params)))
+        else:
+            hesss = np.vstack((self.stoic_problem.jeq_mass_balance_hesss_jit(jx, self.model_params),
+                               self.stoic_problem.jeq_density_hesss_jit(jx, self.model_params)))
+        hess += (lagrange.reshape((-1, 1, 1)) * hesss).sum(axis=0)
+        return hess[np.tril_indices(self._n_vars)]
+
+    def set_report_freq(self, mod_val=0):
+        self.report_freq = mod_val
+
+    # noinspection PyUnusedLocal
+    def intermediate(self, alg_mod, iter_count, obj_value, inf_pr, inf_du, mu,
+                     d_norm, regularization_size, alpha_du, alpha_pr, ls_trials):
+        self.nit = iter_count
+        if self.report_freq > 0:
+            if iter_count % self.report_freq == 0:
+                print(f'[{iter_count:5d}] growth rate: {-obj_value:.4f} 1/h')
diff --git a/xbanalysis/utils/opt_results.py b/xbanalysis/utils/opt_results.py
index 5a03ca3..7468d9b 100644
--- a/xbanalysis/utils/opt_results.py
+++ b/xbanalysis/utils/opt_results.py
@@ -8,16 +8,19 @@ import numpy as np
 
 class OptResults:
 
-    def __init__(self, problem, n_points):
+    def __init__(self, problem, n_points, model_params=None):
         """Initialize OptResults.
 
         :param problem: GbaProblem from where to extract relevant parameters
         :type problem: GbaProblem
+        :param model_params: model parameters (used for jax based problems)
+        :type model_params: dict or None
         :param n_points: number of points to record
         :type n_points: integer
         """
         self.problem = problem
         self._n_points = n_points
+        self.model_params = model_params
         self._idx = 0
         self._n_m = len(self.problem.var_metab_sids)
         self._n_e = len(self.problem.var_enz_sids)
@@ -48,15 +51,24 @@ class OptResults:
             self.ress['grs_per_h'][i] = -3600 * info['obj_val']
             self.ress['cm_M'][i] = cx[self.problem.var_mask_metab]
             self.ress['ce_M'][i] = cx[self.problem.var_mask_enz]
-            rho = self.problem.model_params['mws'] * cx
+            rho = self.problem.var_mws * cx
             self.ress['rho_m'][i] = rho[self.problem.var_mask_metab]
             self.ress['rho_e'][i] = rho[self.problem.var_mask_enz]
 
             scale_factor = 3600 * 1e3 / self.problem.model_params['rho']
-            self.ress['mr_fluxes'][i] = scale_factor * self.problem.mras(cx) / self.problem.mr_vols
-            self.ress['psm_fluxes'][i] = scale_factor / self.problem.pstmas(cx) / self.problem.ps_vols
+            if self.model_params is None:
+                self.ress['mr_fluxes'][i] = scale_factor * self.problem.mras(cx) / self.problem.mr_vols
+                self.ress['psm_fluxes'][i] = scale_factor / self.problem.pstmas(cx) / self.problem.ps_vols
+
+                self.ress['alphas'][i] = self.problem.get_alphas(cx)
+            else:
+                self.ress['mr_fluxes'][i] = (scale_factor * self.problem.mras(cx, self.model_params)
+                                             / self.problem.mr_vols)
+                self.ress['psm_fluxes'][i] = (scale_factor / self.problem.pstmas(cx, self.model_params)
+                                              / self.problem.ps_vols)
+
+                self.ress['alphas'][i] = self.problem.get_alphas(cx, self.model_params)
 
-            self.ress['alphas'][i] = self.problem.get_alphas(cx)
             self._idx += 1
 
     def get(self):
-- 
GitLab