Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
T
Troubled Cell Detection
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Laura Christine Kühle
Troubled Cell Detection
Commits
bea05f67
Commit
bea05f67
authored
3 years ago
by
Laura Christine Kühle
Browse files
Options
Downloads
Patches
Plain Diff
Added documentation to 'ANN_Training'.
parent
1d1344ca
No related branches found
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
ANN_Training.py
+132
-10
132 additions, 10 deletions
ANN_Training.py
with
132 additions
and
10 deletions
ANN_Training.py
+
132
−
10
View file @
bea05f67
...
...
@@ -5,13 +5,14 @@
Code-Style: E226, W503
Docstring-Style: D200, D400
TODO: Add documentation
TODO: Add documentation
-> Done
TODO: Add README for ANN training
TODO: Fix random seed
TODO: Write-protect all data and models
TODO: Put legend outside plot (bbox_to_anchor)
TODO: Put plotting into separate function
TODO: Reduce number of testing epochs to 50
TODO: Adapt docstring to uniform standard
"""
import
numpy
as
np
...
...
@@ -32,10 +33,51 @@ matplotlib.use('Agg')
class
ModelTrainer
(
object
):
def
__init__
(
self
,
config
):
"""
Class for ANN model training.
Trains and tests a model with set loss function and optimizer.
Attributes
----------
model : torch.nn.Module
ANN model instance for evaluation.
loss_function : torch.nn.modules.loss
Function to evaluate loss during model training.
optimizer : torch.optim
Optimizer for model training.
validation_loss : torch.Tensor
Tensor of validation loss values recorded during training.
Methods
-------
epoch_training()
Trains model for a given number of epochs.
test_model()
Evaluates predictions of a model.
save_model()
Saves state and validation loss of a model.
"""
def
__init__
(
self
,
config
:
dict
)
->
None
:
"""
Initializes ModelTrainer.
Parameters
----------
config : dict
Additional parameters for model trainer.
"""
self
.
_reset
(
config
)
def
_reset
(
self
,
config
):
def
_reset
(
self
,
config
:
dict
)
->
None
:
"""
Resets instance variables.
Parameters
----------
config : dict
Additional parameters for model trainer.
"""
self
.
_batch_size
=
config
.
pop
(
'
batch_size
'
,
500
)
self
.
_num_epochs
=
config
.
pop
(
'
num_epochs
'
,
1000
)
self
.
_threshold
=
config
.
pop
(
'
threshold
'
,
1e-5
)
...
...
@@ -64,8 +106,27 @@ class ModelTrainer(object):
self
.
_optimizer
=
getattr
(
torch
.
optim
,
optimizer
)(
self
.
_model
.
parameters
(),
**
optimizer_config
)
self
.
_validation_loss
=
torch
.
zeros
(
self
.
_num_epochs
//
10
)
print
(
type
(
self
.
_model
),
type
(
self
.
_loss_function
),
type
(
self
.
_optimizer
),
type
(
self
.
_validation_loss
))
def
epoch_training
(
self
,
dataset
:
torch
.
utils
.
data
.
dataset
.
TensorDataset
,
num_epochs
:
int
=
None
,
verbose
:
bool
=
True
)
->
None
:
"""
Trains model for a given number of epochs.
Trains model and saves the validation loss. The training stops after the given number of
epochs or if the threshold is reached.
Parameters
----------
dataset : torch.utils.data.dataset.TensorDataset
Training dataset.
num_epochs : int, optional
Number of epochs for training. If None, set to instance value. Default: None.
verbose : bool, optional
Flag whether commentary in console is wanted. Default: True.
def
epoch_training
(
self
,
dataset
,
num_epochs
=
None
,
verbose
=
True
):
"""
print
(
type
(
dataset
))
tic
=
time
.
perf_counter
()
if
num_epochs
is
None
:
num_epochs
=
self
.
_num_epochs
...
...
@@ -117,7 +178,26 @@ class ModelTrainer(object):
if
verbose
:
print
(
f
'
Total runtime:
{
toc
-
tic
:
0.4
f
}
s
\n
'
)
def
test_model
(
self
,
training_set
,
test_set
):
def
test_model
(
self
,
training_set
:
torch
.
utils
.
data
.
dataset
.
TensorDataset
,
test_set
:
torch
.
utils
.
data
.
dataset
.
TensorDataset
)
->
dict
:
"""
Evaluates predictions of a model.
Trains a model and compares the predicted and true results by evaluating precision, recall,
and f-score for both classes, as well as accuracy and AUROC score.
Parameters
----------
training_set : torch.utils.data.dataset.TensorDataset
Training dataset.
test_set : torch.utils.data.dataset.TensorDataset
Test dataset.
Returns
-------
dict
Dictionary containing classification evaluation data.
"""
self
.
epoch_training
(
training_set
,
num_epochs
=
100
,
verbose
=
False
)
self
.
_model
.
eval
()
...
...
@@ -137,7 +217,17 @@ class ModelTrainer(object):
'
F-Score_Smooth
'
:
f_score
[
0
],
'
F-Score_Troubled
'
:
f_score
[
1
],
'
Accuracy
'
:
accuracy
,
'
AUROC
'
:
auroc
}
def
save_model
(
self
,
directory
,
model_name
=
'
test_model
'
):
def
save_model
(
self
,
directory
:
str
,
model_name
:
str
=
'
test_model
'
)
->
None
:
"""
Saves state and validation loss of a model.
Parameters
----------
directory : str
Path to directory in which model is saved.
model_name : str, optional
Name of model for saving. Default:
'
test_model
'
.
"""
# Set paths for files if not existing already
model_dir
=
directory
+
'
/trained models
'
if
not
os
.
path
.
exists
(
model_dir
):
...
...
@@ -148,15 +238,47 @@ class ModelTrainer(object):
torch
.
save
(
self
.
_validation_loss
,
model_dir
+
'
/loss__
'
+
model_name
+
'
.pt
'
)
def
read_training_data
(
directory
,
normalized
=
True
):
def
read_training_data
(
directory
:
str
,
normalized
:
bool
=
True
)
->
torch
.
utils
.
data
.
dataset
.
TensorDataset
:
"""
Reads training data from directory.
Parameters
----------
directory : str
Path to directory in which training data is saved.
normalized : bool, optional
Flag whether normalized data should be used. Default: True.
Returns
-------
torch.utils.data.dataset.TensorDataset
Training dataset.
"""
# Get training dataset from saved file and map to Torch tensor and dataset
input_file
=
directory
+
(
'
/normalized_input_data.npy
'
if
normalized
else
'
/input_data.npy
'
)
output_file
=
directory
+
'
/output_data.npy
'
return
TensorDataset
(
*
map
(
torch
.
tensor
,
(
np
.
load
(
input_file
),
np
.
load
(
output_file
))))
def
evaluate_models
(
models
,
directory
,
num_iterations
=
100
,
colors
=
None
,
compare_normalization
=
False
):
def
evaluate_models
(
models
:
dict
,
directory
:
str
,
num_iterations
:
int
=
100
,
colors
:
dict
=
None
,
compare_normalization
:
bool
=
False
)
->
None
:
"""
Evaluates the classification of a given set of models.
Parameters
----------
models : dict
Dictionary of models to evaluate.
directory : str
Path to directory for saving resulting plots.
num_iterations : int, optional
Number of iterations for evaluation. Default: 100.
colors : dict, optional
Dictionary containing plotting colors. If None, set to default colors. Default: None.
compare_normalization : bool, optional
Flag whether both normalized and raw data should be evaluated. Default: False.
"""
tic
=
time
.
perf_counter
()
if
colors
is
None
:
colors
=
{
'
Accuracy
'
:
'
magenta
'
,
'
Precision_Smooth
'
:
'
red
'
,
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment