diff --git a/xbanalysis/problems/gba_problem.py b/xbanalysis/problems/gba_problem.py index 0eccbb6e93ab4464aa8b99b1404d0f7479abaf55..ee305403a886caa00653a25a65dfd46daa72072a 100644 --- a/xbanalysis/problems/gba_problem.py +++ b/xbanalysis/problems/gba_problem.py @@ -28,6 +28,7 @@ def cache_last(func): # TODO: call get_reaction_rates from solver ? # TODO: split get_reaction_rates, get_reaction_jacs, get_reaction_hesses # TODO: number of mb equations, enzyme transitions density constraints +# number of constraint from len(var_metab_sids), subenz_x_trans.shape[0], 1 class GbaProblem: @@ -113,6 +114,9 @@ class GbaProblem: self.pstmas, self.pstmas_jac, self.pstmas_hesss = self.get_reactions(self.ps_rids, inv=True) self.degradas, self.degradas_jac, self.degradas_hesss = self.get_reactions(enz_degrad_rids, inv=False) self.enz_transas, self.enz_transas_jac, self.enz_transas_hesss = self.get_reactions(enz_trans_rids, inv=False) + self.n_constraints = {'heq_mass_balance': len(self.var_metab_sids), + 'heq_enz_trans': self.subenz_x_trans.shape[0], + 'heq_density': 1} def get_stoic_matrix(self, sids, rids): """retrieve stoichiometric sub-matrix [sids x rids]. diff --git a/xbanalysis/solvers/gba_ipopt_problem.py b/xbanalysis/solvers/gba_ipopt_problem.py index 480c30984387b401a5939051a05c57c4f7efe3c1..7f3322bd1b1c09b2651574c386897a0b8f24b8ff 100644 --- a/xbanalysis/solvers/gba_ipopt_problem.py +++ b/xbanalysis/solvers/gba_ipopt_problem.py @@ -22,10 +22,7 @@ class GbaIpoptProblem: self.nfev = 0 self.njev = 0 self.nit = 0 - tmp_x = np.ones(self._n_vars) * .01 - self.constraint_dims = np.array([len(gba_problem.heq_mass_balance(tmp_x)), - len(gba_problem.heq_enz_trans(tmp_x)), - len(gba_problem.heq_density(tmp_x))]) + self.constraint_dims = sum(gba_problem.n_constraints.values()) def objective(self, x): self.nfev += 1 @@ -36,24 +33,41 @@ class GbaIpoptProblem: return -self.gba_problem.get_growth_rate_grad(x) def constraints(self, x): - return np.concatenate((self.gba_problem.heq_mass_balance(x), - self.gba_problem.heq_enz_trans(x), - self.gba_problem.heq_density(x))) + # TODO: improve code + if self.gba_problem.n_constraints['heq_enz_trans'] > 0: + eqs = np.hstack((self.gba_problem.heq_mass_balance(x), + self.gba_problem.heq_enz_trans(x), + self.gba_problem.heq_density(x))) + else: + eqs = np.hstack((self.gba_problem.heq_mass_balance(x), + self.gba_problem.heq_density(x))) + return eqs def jacobian(self, x): - return (np.vstack((self.gba_problem.heq_mass_balance_jac(x), - self.gba_problem.heq_enz_trans_jac(x), - self.gba_problem.heq_density_jac(x))) - ).flatten() + # TODO: improve code + if self.gba_problem.n_constraints['heq_enz_trans'] > 0: + jacs = (np.vstack((self.gba_problem.heq_mass_balance_jac(x), + self.gba_problem.heq_enz_trans_jac(x), + self.gba_problem.heq_density_jac(x))) + ).flatten() + else: + jacs = (np.vstack((self.gba_problem.heq_mass_balance_jac(x), + self.gba_problem.heq_density_jac(x))) + ).flatten() + return jacs def hessianstructure(self): return np.tril_indices(self._n_vars) def hessian(self, x, lagrange, obj_factor): hess = obj_factor * -self.gba_problem.get_growth_rate_hess(x) - hesss = np.vstack((self.gba_problem.heq_mass_balance_hess(x), - self.gba_problem.heq_enz_trans_hess(x), - self.gba_problem.heq_density_hess(x))) + if self.gba_problem.n_constraints['heq_enz_trans'] > 0: + hesss = np.vstack((self.gba_problem.heq_mass_balance_hess(x), + self.gba_problem.heq_enz_trans_hess(x), + self.gba_problem.heq_density_hess(x))) + else: + hesss = np.vstack((self.gba_problem.heq_mass_balance_hess(x), + self.gba_problem.heq_density_hess(x))) hess += (lagrange.reshape((-1, 1, 1)) * hesss).sum(axis=0) return hess[np.tril_indices(self._n_vars)]