diff --git a/setup.py b/setup.py
index 7bf9b347f6ea02f9c79d20c0e0393ea196bcd0fd..e7b4b5380c955d22e50707cc10453baaac103f98 100755
--- a/setup.py
+++ b/setup.py
@@ -40,8 +40,9 @@ setup(
                       'sympy >= 1.9.0',
                       'scipy >= 1.8.1',
                       'swiglpk >= 5.0.5',
-                      'sbmlxdf >= 0.2.7'],
-    python_requires=">=3.7",
+                      'sbmlxdf >= 0.2.7',
+                      'jax >= 0.3.13'],
+    python_requires=">=3.9",
     keywords=['modeling', 'standardization', 'SBML'],
     include_package_data=True,
     **setup_kwargs
diff --git a/xbanalysis/problems/gba_problem.py b/xbanalysis/problems/gba_problem.py
index dee8097ba318e7c0c435f57451aa4f4390fb530c..55d9e58a6bb301632edcdb6eba296f040fd5296f 100644
--- a/xbanalysis/problems/gba_problem.py
+++ b/xbanalysis/problems/gba_problem.py
@@ -46,12 +46,13 @@ class GbaProblem:
         self.scale_density = 1e-4
 
         # metabolites and enzymes in the model
-        metab_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000247')]
-        enz_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000246')]
+        metab_sids = [s.id for s in xba_model.species.values()
+                      if s.sboterm.is_in_branch('SBO:0000247')]
 
         # optimization variables
         self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False]
-        self.var_enz_sids = enz_sids
+        self.var_enz_sids = [s.id for s in xba_model.species.values()
+                             if s.sboterm.is_in_branch('SBO:0000246')]
         self.var_sids = self.var_metab_sids + self.var_enz_sids
         self.n_vars = len(self.var_sids)
         self.var_sid2idx = {sid: idx for idx, sid in enumerate(self.var_sids)}
diff --git a/xbanalysis/problems/gba_sflux_problem.py b/xbanalysis/problems/gba_sflux_problem.py
index a8f985728fb1a09086d05b41241aaa5b72405eda..26a8b09957f013f31ca55260f6e94db94b5a88f4 100644
--- a/xbanalysis/problems/gba_sflux_problem.py
+++ b/xbanalysis/problems/gba_sflux_problem.py
@@ -17,10 +17,16 @@ import pandas as pd
 import re
 import sympy as sp
 import sbmlxdf
+import scipy
+import jax
+import types
+import jax.numpy as jnp
 
 from xbanalysis.model.xba_unit_def import XbaUnit
 from xbanalysis.utils.utils import expand_kineticlaw
 
+jax.config.update("jax_enable_x64", True)
+
 
 class GbaSfluxProblem:
 
@@ -37,7 +43,8 @@ class GbaSfluxProblem:
         metab_sids = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000247')]
         self.const_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is True]
         self.var_metab_sids = [sid for sid in metab_sids if xba_model.species[sid].constant is False]
-        sid_ribo = [s.id for s in xba_model.species.values() if s.sboterm.is_in_branch('SBO:0000250')][0]
+        sid_ribo = [s.id for s in xba_model.species.values()
+                    if s.sboterm.is_in_branch('SBO:0000250')][0]
         self.var_sids = self.var_metab_sids + [sid_ribo]
         self.sids = self.const_metab_sids + self.var_sids
         self.sid2idx = {msid: idx for idx, msid in enumerate(self.sids)}
@@ -47,37 +54,21 @@ class GbaSfluxProblem:
         rid_ps_ribo = xba_model.species[sid_ribo].ps_rid
         self.rids = self.mr_rids + [rid_ps_ribo]
         self.rid2idx = {mrid: idx for idx, mrid in enumerate(self.rids)}
+        self.enz_sids = np.array([self.xba_model.reactions[rid].enzyme for rid in self.rids])
 
         self.mw_sids = np.array([xba_model.species[sid].mw for sid in self.sids])
-        self.mw_enz = np.array([xba_model.species[xba_model.reactions[rid].enzyme].mw
-                                for rid in self.rids])
-        # mw_enz is used for scaling of kcats, weight for protein synthesis needs to be rescaled
-        s = xba_model.species[sid_ribo]
-        if hasattr(s, 'rba_process_costs'):
-            p_len = s.rba_process_costs['translation']
-        else:
-            r = xba_model.reactions[rid_ps_ribo]
-            p_len = min(r.reactants.values())
-        self.mw_enz[-1] *= p_len
+        self.mw_enz = np.array([xba_model.species[enz_sid].mw for enz_sid in self.enz_sids])
 
-        # retrieve S-matrix: internal metabos + total protein and metabolic reaction + ribosome synthesis
+        # retrieve S-matrix and calulate mass turnover in gram reactants/products per mol reactions
         sids_x_rids = xba_model.get_stoic_matrix(self.sids, self.rids)
+        mass_turnover = self.get_mass_turnover(sids_x_rids)
+        self.rt_conv_factor = 3600.0 * mass_turnover / self.mw_enz
 
-        # calulate mass turnover in gram reactants/products per mol reactions (g/mol)
-        sids_x_rids_react = np.where(sids_x_rids > 0, 0, sids_x_rids)
-        sids_x_rids_prod = np.where(sids_x_rids < 0, 0, sids_x_rids)
-        mass_turnover_react = -sids_x_rids_react.T.dot(self.mw_sids)
-        mass_turnover_prod = sids_x_rids_prod.T.dot(self.mw_sids)
-        self.mass_turnover = np.max((mass_turnover_react, mass_turnover_prod), axis=0)
-
-        # mass normalized stoichiometric matrix
-        self.sidx_x_rids_mn = (sids_x_rids * self.mw_sids.reshape(-1, 1)) / self.mass_turnover
-
-        self.rho = xba_model.parameters['rho'].value
-        self.ext_conc_mn = np.array([xba_model.species[sid].initial_conc * xba_model.species[sid].mw
-                                     for sid in self.const_metab_sids])
-        self.var_initial_mn = np.array([xba_model.species[sid].initial_conc * xba_model.species[sid].mw
-                                        for sid in self.var_metab_sids])
+        # mass normalized stoichiometric matrices with/without external metabolites
+        self.sids_x_rids_mn = (sids_x_rids * self.mw_sids.reshape(-1, 1)) / mass_turnover
+        self.var_sids_x_rids_mn = self.sids_x_rids_mn[-len(self.var_sids):, ]
+        self.colsum_var_sids_x_rids_mn = np.sum(self.var_sids_x_rids_mn, axis=0)
+        self.dof = self.var_sids_x_rids_mn.shape[1] - np.linalg.matrix_rank(self.var_sids_x_rids_mn)
 
         # identify unit ids for kinetic parameters
         units_per_s = [XbaUnit({'kind': 'second', 'exp': -1.0, 'scale': 0, 'mult': 1.0})]
@@ -90,6 +81,125 @@ class GbaSfluxProblem:
         self.kms_global = {pid: p.value for pid, p in xba_model.parameters.items()
                            if p.units == self.km_udid}
 
+        # model parameters
+        self.model_params = {pid: p.value for pid, p in xba_model.parameters.items()}
+        self.model_params['ext_conc_mn'] = np.array([xba_model.species[sid].initial_conc *
+                                                     xba_model.species[sid].mw
+                                                     for sid in self.const_metab_sids])
+        self.model_params['var_initial_mn'] = np.array([xba_model.species[sid].initial_conc *
+                                                        xba_model.species[sid].mw
+                                                        for sid in self.var_metab_sids])
+
+        # create functions used for optimization
+        self.jrts = self.get_reaction_times(self.rids)
+        self.obj_mu = jax.jit(self.jmu)
+        self.obj_mu_grad = jax.jit(jax.grad(self.jmu))
+        self.obj_mu_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jmu)))
+        self.constr_density = jax.jit(self.jg)
+        self.constr_density_grad = jax.jit(jax.grad(self.jg))
+        self.constr_density_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jg)))
+        self.constrs_conc = jax.jit(self.jconcentrations)
+        self.constrs_conc_jac = jax.jit(jax.jacfwd(self.jconcentrations))
+        self.constrs_conc_hesss = jax.jit(jax.jacrev(jax.jacfwd(self.jconcentrations)))
+
+        self.n_constraints = {'constr_density': 1,
+                              'constrs_conc': len(self.var_sids) + len(self.enz_sids),
+                              }
+
+    def recompile(self):
+        # recompiliation of function is necessary after parameters have been changed
+        self.jrts = self.get_reaction_times(self.rids)
+        self.obj_mu = jax.jit(self.jmu)
+        self.obj_mu_grad = jax.jit(jax.grad(self.jmu))
+        self.obj_mu_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jmu)))
+        self.constr_density = jax.jit(self.jg)
+        self.constr_density_grad = jax.jit(jax.grad(self.jg))
+        self.constr_density_hess = jax.jit(jax.jacrev(jax.jacfwd(self.jg)))
+        self.constrs_conc = jax.jit(self.jconcentrations)
+        self.constrs_conc_jac = jax.jit(jax.jacfwd(self.jconcentrations))
+        self.constrs_conc_hesss = jax.jit(jax.jacrev(jax.jacfwd(self.jconcentrations)))
+
+    def get_w0_initial(self):
+        p_initial_mn = self.model_params['rho'] - sum(self.model_params['var_initial_mn'])
+        initial_mn = np.hstack((self.model_params['var_initial_mn'], p_initial_mn))
+        w0 = np.ravel(scipy.linalg.pinv(self.var_sids_x_rids_mn).dot(initial_mn) / self.model_params['rho'])
+        return w0
+
+    def get_mass_turnover(self, sids_x_rids):
+        # calculate the turnover in gram reactants/products per mol reactions
+        # from full stoichiometric matrix (including external metabolites)
+        sids_x_rids_react = np.where(sids_x_rids > 0, 0, sids_x_rids)
+        sids_x_rids_prod = np.where(sids_x_rids < 0, 0, sids_x_rids)
+        mass_turnover_react = -sids_x_rids_react.T.dot(self.mw_sids)
+        mass_turnover_prod = sids_x_rids_prod.T.dot(self.mw_sids)
+        return np.max((mass_turnover_react, mass_turnover_prod), axis=0)
+
+    def jtau(self, w, model_params):
+        ci_mass = self.jci(w, model_params)
+        ca_mass = jnp.hstack((model_params['ext_conc_mn'], ci_mass))
+        ca_molar = jnp.divide(ca_mass, self.mw_sids)
+        return jnp.divide(self.jrts(ca_molar, model_params), self.rt_conv_factor)
+
+    def jmu(self, w, model_params):
+        return jnp.divide(self.var_sids_x_rids_mn[-1, -1] * w[-1], self.jtau(w, model_params).dot(w))
+
+    # Density constraint
+    # sN[:, b].dot(w[b]) -1 , with b: indices of boundary reactions
+    def jg(self, w, model_params):
+        return jnp.dot(self.colsum_var_sids_x_rids_mn, w) - 1.0
+
+    # protein concentrations
+    def jp(self, w, model_params):
+        return self.jmu(w, model_params) * model_params['rho'] * self.jtau(w, model_params) * w
+
+    # internal concentrations "i" of substrates "s" and total protein "p"
+    def jci(self, w, model_params):
+        return model_params['rho'] * jnp.dot(self.var_sids_x_rids_mn, w)
+
+    # constraint of non-negative concentrations hin(x) ≥ 0 for nlopt, !! for scipy: ≤ 0
+    def jconcentrations(self, w, model_params):
+        return jnp.hstack((self.jci(w, model_params), self.jp(w, model_params)))
+
+    def _to_vectors(self, rs_str):
+        """replace species ids by vector elements, compartment and enzyme ids by 1.0
+
+        Variables converted to x[]
+
+        :param rs_str: reaction strings (extracted from the xba model).
+        :type rs_str: list of strings
+        :returns: rs_v_str
+        :rtype: list of reaction stings with parameters replaced by vector elements
+        """
+        # replace compartment variables with 1.0
+        # replace enzyme variables with 1.0
+        # replace species ids by vector components
+        # make model parameters accessible
+        rs_v_str = []
+        for r_v_str in rs_str:
+            for compartment_id in self.xba_model.compartments:
+                r_v_str = re.sub(r'\b' + compartment_id + r'\b', str(1.0), r_v_str)
+            for enz_sid in self.enz_sids:
+                r_v_str = re.sub(r'\b' + enz_sid + r'\b', str(1.0), r_v_str)
+            for sid, idx in self.sid2idx.items():
+                r_v_str = re.sub(r'\b' + sid + r'\b', f'x[{idx}]', r_v_str)
+            for param_key in self.model_params:
+                r_v_str = re.sub(r'\b' + param_key + r'\b', f'model_params["{param_key}"]', r_v_str)
+            rs_v_str.append(r_v_str)
+        return rs_v_str
+
+    def get_reaction_times(self, rids):
+        # molar reaction turnover times
+        # hard-code used namespace prefix for numpy 'np' to a fixed prefix on 'jnp'
+        # model parameters in reaction times functions as function parameter
+        turnover_times_str = [f'1.0/({self.xba_model.reactions[rid].expanded_kl})'
+                              for rid in rids]
+        vs_str = self._to_vectors(turnover_times_str)
+        func_code = compile('def turnover_times(x, model_params): '
+                            'return np.array([' + ', '.join(vs_str) + '])',
+                            '<string>', 'exec')
+        jrts = types.FunctionType(func_code.co_consts[0], {'np': jnp})
+        return jrts
+
     def get_unit_def_id(self, query_units):
         for udid, ud in self.xba_model.unit_defs.items():
             if ud.is_equivalent(query_units):
@@ -128,7 +238,6 @@ class GbaSfluxProblem:
         :return:
         :rtype:
         """
-
         r = self.xba_model.reactions[rid]
         kms_local = {pid: val for pid, [val, units] in r.local_params.items() if units == self.km_udid}
 
@@ -166,13 +275,23 @@ class GbaSfluxProblem:
         :return: gba model mass normalized as per Hugo Dourado
         :rtype: dict of pandas DataFrames
         """
-        protein_mn = self.rho - sum(self.var_initial_mn)
+        protein_mn = self.model_params['rho'] - sum(self.model_params['var_initial_mn'])
 
         kcats = np.array([self.get_kcats(rid) for rid in self.rids]).T
-        skcats = kcats * 3600.0 * self.mass_turnover / self.mw_enz
 
-        km = np.zeros_like(self.sidx_x_rids_mn)
-        ki = np.zeros_like(self.sidx_x_rids_mn)
+        # in the SBML model, the kcat for protein synthesis is scaled by protein lenght
+        s = self.xba_model.species[self.enz_sids[-1]]
+        if hasattr(s, 'rba_process_costs'):
+            p_len = s.rba_process_costs['translation']
+        else:
+            r = self.xba_model.reactions[self.rids[-1]]
+            p_len = min(r.reactants.values())
+
+        skcats = kcats * self.rt_conv_factor
+        skcats[:, -1] = skcats[:, -1] / p_len
+
+        km = np.zeros_like(self.sids_x_rids_mn)
+        ki = np.zeros_like(self.sids_x_rids_mn)
         for rid in self.rids:
             kms, kis = self.get_kms(rid)
             for sid, val in kms.items():
@@ -182,19 +301,19 @@ class GbaSfluxProblem:
         # scale Michaelis constants with molecluar weights
         skm = km * self.mw_sids.reshape(-1, 1)
         ski = ki * self.mw_sids.reshape(-1, 1)
-        ska = np.zeros_like(self.sidx_x_rids_mn)
+        ska = np.zeros_like(self.sids_x_rids_mn)
 
         mn_model = {
-            'N': pd.DataFrame(self.sidx_x_rids_mn, index=self.sids, columns=self.rids),
+            'N': pd.DataFrame(self.sids_x_rids_mn, index=self.sids, columns=self.rids),
             'KM': pd.DataFrame(skm, index=self.sids, columns=self.rids),
             'KI': pd.DataFrame(ski, index=self.sids, columns=self.rids),
             'KA': pd.DataFrame(ska, index=self.sids, columns=self.rids),
             'kcat': pd.DataFrame(skcats, index=['kcat_f', 'kcat_b'], columns=self.rids),
-            'conditions': pd.DataFrame(np.hstack((self.rho, self.ext_conc_mn)),
+            'conditions': pd.DataFrame(np.hstack((self.model_params['rho'], self.model_params['ext_conc_mn'])),
                                        index=['rho'] + self.const_metab_sids, columns=[1]),
             'lower_c': pd.DataFrame(np.zeros(len(self.var_sids)),
                                     index=self.var_sids, columns=['lower']),
-            'initial': pd.DataFrame(np.hstack((self.var_initial_mn, protein_mn)),
+            'initial': pd.DataFrame(np.hstack((self.model_params['var_initial_mn'], protein_mn)),
                                     index=self.var_sids, columns=['initial']),
         }
         return mn_model
diff --git a/xbanalysis/solvers/__init__.py b/xbanalysis/solvers/__init__.py
index 5907eaa15781be96589add06e40e15e15185ef89..4e4ee2ebb3ce25a72131cc942ddd2729f5b1ae21 100644
--- a/xbanalysis/solvers/__init__.py
+++ b/xbanalysis/solvers/__init__.py
@@ -2,5 +2,6 @@
 
 from .glpk_linear_problem import GlpkLinearProblem
 from .gba_ipopt_problem import GbaIpoptProblem
+from .gba_sflux_ipopt_problem import GbaSfluxIpoptProblem
 
-__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem']
+__all__ = ['GlpkLinearProblem', 'GbaIpoptProblem', 'GbaSfluxIpoptProblem']
diff --git a/xbanalysis/solvers/gba_sflux_ipopt_problem.py b/xbanalysis/solvers/gba_sflux_ipopt_problem.py
new file mode 100644
index 0000000000000000000000000000000000000000..49718a3b9277bf112eabfb227ce2ba3f74850de3
--- /dev/null
+++ b/xbanalysis/solvers/gba_sflux_ipopt_problem.py
@@ -0,0 +1,75 @@
+"""Implementation of GbaSfluxIpoptProblem class.
+
+For solving GBA problem using scaled fluxes as per Hugo Dourado
+
+Peter Schubert, HHU Duesseldorf, September 2022
+"""
+
+import numpy as np
+
+
+# based on cyipopt examples
+class GbaSfluxIpoptProblem:
+
+    def __init__(self, gba_sflux_problem, model_params=None):
+        """Initialize sFluxIpoptProblem.
+
+        :param gba_sflux_problem: GbaSfluxProblem from where to extract relevant parameters
+        :type gba_sflux_problem: GbaSfluxProblem
+        :param model_params: model parameters (optional, default None, used params from GbaSfluxProblem
+        :type model_params: None or dict
+        """
+        self.sflux_problem = gba_sflux_problem
+        if model_params is None:
+            self.model_params = gba_sflux_problem.model_params
+        else:
+            self.model_params = model_params
+        self._n_vars = len(gba_sflux_problem.rids)
+        self.report_freq = 0
+
+        self.nfev = 0
+        self.njev = 0
+        self.nit = 0
+        self.constraint_dims = sum(gba_sflux_problem.n_constraints.values())
+
+    def objective(self, x):
+        self.nfev += 1
+        return -np.float64(self.sflux_problem.obj_mu(x, self.model_params))
+
+    def gradient(self, x):
+        self.njev += 1
+        return -np.array(self.sflux_problem.obj_mu_grad(x, self.model_params))
+
+    def constraints(self, x):
+        # TODO: improve code
+        eqs = np.hstack((self.sflux_problem.constr_density(x, self.model_params),
+                         self.sflux_problem.constrs_conc(x, self.model_params)))
+        return eqs
+
+    def jacobian(self, x):
+        # TODO: improve code
+        jacs = (np.vstack((self.sflux_problem.constr_density_grad(x, self.model_params),
+                           self.sflux_problem.constrs_conc_jac(x, self.model_params)))
+                ).flatten()
+        return jacs
+
+    # def hessianstructure(self):
+    #    return np.tril_indices(self._n_vars)
+
+    def hessian(self, x, lagrange, obj_factor):
+        hess = obj_factor * -self.sflux_problem.obj_mu_hess(x, self.model_params)
+        hesss = np.vstack(([self.sflux_problem.constr_density_hess(x, self.model_params)],
+                           self.sflux_problem.constrs_conc_hesss(x, self.model_params)))
+        hess += (lagrange.reshape((-1, 1, 1)) * hesss).sum(axis=0)
+        return hess[np.tril_indices(self._n_vars)]
+
+    def set_report_freq(self, mod_val=0):
+        self.report_freq = mod_val
+
+    # noinspection PyUnusedLocal
+    def intermediate(self, alg_mod, iter_count, obj_value, inf_pr, inf_du, mu,
+                     d_norm, regularization_size, alpha_du, alpha_pr, ls_trials):
+        self.nit = iter_count
+        if self.report_freq > 0:
+            if iter_count % self.report_freq == 0:
+                print(f'[{iter_count:5d}] growth rate: {-obj_value:.4f} 1/h')