diff --git a/modelpruner/core/model_pruner.py b/modelpruner/core/model_pruner.py index 7c97eb3f9285687bc7219d2e8739df4dbe0ba890..9a841427e29870557ee59be3193d8e273aed195e 100644 --- a/modelpruner/core/model_pruner.py +++ b/modelpruner/core/model_pruner.py @@ -6,6 +6,7 @@ import re import time import math import numpy as np +import pandas as pd import multiprocessing as mp import json @@ -78,7 +79,7 @@ class ModelPruner: :param processes: maximal number of processes for multiprocessing :type processes: int (default: 100) """ - self.tolerance = 1e-7 # later accessible by set_params() ? + self.tolerance = 1e-12 # later accessible by set_params() ? self.cpus = min(processes, os.cpu_count()) self.min_mp_total_flux = 1000 self.full_sbmnl_fname = sbml_fname @@ -89,8 +90,8 @@ class ModelPruner: if resume is True and os.path.exists(self._snapshot_sbml): sbml_fname = self._snapshot_sbml self.fba_model = FbaModel(sbml_fname) - self.nrp = ProtectedParts(protected_parts) - self.protected_sids = self.nrp.protected_sids + self.pps = ProtectedParts(protected_parts) + self.protected_sids = self.pps.protected_sids if resume is True and os.path.exists(self._snapshot_json): with open(self._snapshot_json, 'r') as f: @@ -99,11 +100,11 @@ class ModelPruner: print(f'Snapshot restored at {len(self.fba_model.rids):4d} remaining reactions ' f'from: {self._snapshot_json}') else: - self.protected_rids = self.nrp.protected_rids + self.protected_rids = self.pps.protected_rids # self.protected_rids = self.get_total_protected_rids() # overwrite the fba model bounds by general bound of protected parts (priot to LpProblem instantiation) - self.fba_model.update_model_flux_bounds(self.nrp.overwrite_bounds) + self.fba_model.update_model_flux_bounds(self.pps.overwrite_bounds) self.fba_model.update_model_reversible(self.get_pp_reversible_reactions()) print('Determine optimal values for protected functions') @@ -124,7 +125,7 @@ class ModelPruner: for sid in self.fba_model.sids]) psf_mask = self.fba_model.s_mat_coo.tocsr()[sid_mask].getnnz(axis=0) protected_sid_rids = set(self.fba_model.rids[psf_mask]) - protected_rids = self.nrp.protected_rids.union(protected_sid_rids) + protected_rids = self.pps.protected_rids.union(protected_sid_rids) return protected_rids def get_pp_reversible_reactions(self): @@ -134,11 +135,11 @@ class ModelPruner: :rtype: set of str """ rev_pp_rids = set() - for rid, bounds in self.nrp.overwrite_bounds.items(): + for rid, bounds in self.pps.overwrite_bounds.items(): for idx in [0, 1]: if math.isfinite(bounds[idx]): rev_pp_rids.add(rid) - for pf in self.nrp.protected_functions.values(): + for pf in self.pps.protected_functions.values(): for rid, bounds in pf.overwrite_bounds.items(): for idx in [0, 1]: if math.isfinite(bounds[idx]) and bounds[idx] < 0.0: @@ -147,7 +148,7 @@ class ModelPruner: def set_function_optimum(self): """using FBA to set optimal value for protected functions""" - for pf in self.nrp.protected_functions.values(): + for pf in self.pps.protected_functions.values(): res = self.fba_model.fba_pf_optimize(pf) pf.optimal = res['fun'] pf.fba_success = res['success'] @@ -158,14 +159,14 @@ class ModelPruner: pf.set_target_ub(pf.target_frac * pf.optimal) def check_protected_parts(self): - """Report on consisteny issues wrt to protected parts. + """Report on consistency issues wrt to protected parts. E.g. report on reactions and metabolties that do not exist in the model """ extra_p_rids = self.protected_rids.difference(set(self.fba_model.rids)) extra_p_sids = self.protected_sids.difference(set(self.fba_model.sids)) pf_rids = set() - for pf in self.nrp.protected_functions.values(): + for pf in self.pps.protected_functions.values(): pf_rids |= {rid for rid in pf.objective} pf_rids |= {rid for rid in pf.overwrite_bounds} if pf.fba_success is False: @@ -184,7 +185,7 @@ class ModelPruner: print('Reactions in protected functions that are not protected:', pf_unprotected_rids) blocked_rids = self.reaction_types_fva(self.fba_model.rids)['blocked_rids'] - blocked_p_rids = self.nrp.protected_sids.intersection(blocked_rids) + blocked_p_rids = self.pps.protected_sids.intersection(blocked_rids) blocked_pf_rids = pf_rids.intersection(blocked_rids) if len(blocked_p_rids) > 0: print('Protected reactions that are blocked (no flux):', blocked_p_rids) @@ -230,7 +231,7 @@ class ModelPruner: """ print(time.strftime("%H:%M:%S", time.localtime())) - n_pf = len(self.nrp.protected_functions) + n_pf = len(self.pps.protected_functions) flux_min_pfs = np.zeros((n_pf, len(free_rids))) flux_max_pfs = np.zeros((n_pf, len(free_rids))) @@ -246,13 +247,13 @@ class ModelPruner: # delete lp problem prior to pool creation (so FbaProblem can be pickled) self.fba_model.delete_fba_lp() with mp.Pool(processes, initializer=_init_worker, - initargs=(self.fba_model, self.nrp.protected_functions)) as pool: + initargs=(self.fba_model, self.pps.protected_functions)) as pool: for res in pool.imap_unordered(_worker, np.array_split(free_rids, processes)): idxs = [frid2idx[rid] for rid in res['rids']] flux_min_pfs[:, idxs] = res['min_pfs'] flux_max_pfs[:, idxs] = res['max_pfs'] else: - for idx, pf in enumerate(self.nrp.protected_functions.values()): + for idx, pf in enumerate(self.pps.protected_functions.values()): # print(f'\nCheck FVA for pf {pf.obj_id} with {len(free_rids)} free variables') res = self.fba_model.fva_pf_flux_ranges(pf, free_rids) if res['success'] is True: @@ -330,7 +331,7 @@ class ModelPruner: Step during Pruning 1. identification and removal of parallel reactions - 2. check feasibility of the network wrt to all protected functions + 2. check feasibility of the network wrt to all protected functions and reactions 3. in a loop till all free reactions have been consumed 3.1 identify types of remaining free reactions using FVA across all protected functions 3.2 blocked reactions (zero flux across all functions) get removed @@ -348,7 +349,8 @@ class ModelPruner: self.fba_model.remove_reactions(drop_rids) print(f'{len(drop_rids)} parallel reaction(s) dropped:', drop_rids) - feasible = np.array([self.fba_model.check_pf_feasibility(pf) for pf in self.nrp.protected_functions.values()]) + feasible = np.array([self.fba_model.check_pf_feasibility(pf) + for pf in self.pps.protected_functions.values()]) print('Protected functions feasibility:', np.all(feasible)) next_snapshot = 0 @@ -371,7 +373,7 @@ class ModelPruner: for rid in candidate_rids: # print('Check feasibility of', rid) feasible = True - for pf in self.nrp.protected_functions.values(): + for pf in self.pps.protected_functions.values(): feasible = self.fba_model.check_pf_feasibility(pf, zero_flux_rid=rid) # print(f'Feasibility of {rid} is {feasible}') if feasible is False: @@ -447,7 +449,14 @@ class ModelPruner: if pruned_sbml is None: pruned_sbml = re.sub(r'.xml$', f'_pruned_{n_s}x{n_r}.xml', self.full_sbmnl_fname) - success = self.fba_model.export_pruned_model(pruned_sbml) + # overwrite model objective with objective of first protected function + first_pf = list(self.pps.protected_functions.values())[0] + fba_objective = '; '.join([f'reac={reac}, coef={coef}' + for reac, coef in first_pf.objective.items()]) + df_fbc_objectives = pd.DataFrame(np.array([[first_pf.obj_dir, True, fba_objective]]), + index=['obj'], columns=['type', 'active', 'fluxObjectives']) + df_fbc_objectives.index.name = 'id' + success = self.fba_model.export_pruned_model(pruned_sbml, df_fbc_objectives) if success is True: print('Pruned model exported to', pruned_sbml) if len(self.protected_rids) == n_r: