Skip to content
Snippets Groups Projects
Select Git revision
  • 43bcfcc3258418589e5d6f9bdc89e1d845685ef7
  • master default protected
2 results

test_model.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    test_model.py 1.12 KiB
    # -*- coding: utf-8 -*-
    """Script to test ANN models.
    
    @author: Laura C. Kühle
    """
    import sys
    
    from tcd.ANN_Training import ModelTrainer, evaluate_models
    
    
    def main() -> None:
        """Build the configured ANN models and evaluate them.

        All configuration is read from the ``snakemake`` object that must exist
        in this script's global namespace:

        * ``snakemake.log[0]`` -- path of the log file. Everything written to
          ``sys.stdout``/``sys.stderr`` while the models are set up and
          evaluated is redirected into it.
        * ``snakemake.params['models']`` -- mapping of model name to extra
          settings that are unpacked into each ``ModelTrainer`` config.
        * ``snakemake.params['directory']`` -- used as both ``dir`` and
          ``model_dir`` for every trainer, and as ``directory`` for
          ``evaluate_models``.
        * ``snakemake.params['random_seed']``, ``['num_iterations']`` and
          ``['compare_normalization']`` -- forwarded unchanged.

        The original stdout/stderr streams are restored when the function
        exits, normally or via an exception. The log file is therefore never
        left installed as a closed, process-wide stream, and uncaught
        exceptions can still be reported after it is closed.
        """
        from contextlib import redirect_stderr, redirect_stdout

        with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
            # Unlike plain assignment to sys.stdout/sys.stderr, these context
            # managers put the previous streams back on exit.
            with redirect_stdout(logfile), redirect_stderr(logfile):
                params = snakemake.params
                directory = params['directory']

                # Initialize models to be evaluated. The per-model settings are
                # unpacked last, so on a key clash they win over the shared
                # entries listed before them.
                models = {
                    name: ModelTrainer({
                        'model_name': name, 'dir': directory,
                        'model_dir': directory,
                        'random_seed': params['random_seed'],
                        **settings})
                    for name, settings in params['models'].items()}

                # Evaluate models
                evaluate_models(
                    models=models, directory=directory,
                    num_iterations=params['num_iterations'],
                    compare_normalization=params['compare_normalization'])
    
    
    if __name__ == '__main__':
        # main() reads a ``snakemake`` object from the module namespace; if no
        # such name is present, report that and skip evaluation.
        if "snakemake" not in locals():
            print('Not Defined.')
        else:
            main()