Project: general/dsml/emoUS-public · Commits

Commit c8c19090, authored Aug 4, 2022 by Carel van Niekerk
Parent: 06491f2b

    Refactor and rename calibration to evaluate

No related tags or merge requests reference this commit.

Showing 2 changed files, with 76 additions and 164 deletions:
  convlab/dst/setsumbt/do/calibration.py → convlab/dst/setsumbt/do/evaluate.py (+71, −161)
  convlab/dst/setsumbt/modeling/calibration_utils.py → convlab/dst/setsumbt/modeling/evaluation_utils.py (+5, −3)
convlab/dst/setsumbt/do/calibration.py → convlab/dst/setsumbt/do/evaluate.py (+71, −161)
 # -*- coding: utf-8 -*-
-# Copyright 2021 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -16,33 +16,22 @@
 """Run SetSUMBT Calibration"""

 import logging
-import random
 import os
-from shutil import copy2 as copy

 import torch
 from transformers import (BertModel, BertConfig, BertTokenizer,
-                          RobertaModel, RobertaConfig, RobertaTokenizer,
-                          AdamW, get_linear_schedule_with_warmup)
-from tqdm import tqdm, trange
-from tensorboardX import SummaryWriter
-from torch.distributions import Categorical
+                          RobertaModel, RobertaConfig, RobertaTokenizer)

-from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
-from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
-from convlab.dst.setsumbt.multiwoz import multiwoz21
-from convlab.dst.setsumbt.multiwoz import ontology as embeddings
-from convlab.dst.setsumbt.utils import get_args, upload_local_directory_to_gcs, update_args
-from convlab.dst.setsumbt.modeling import calibration_utils
-from convlab.dst.setsumbt.modeling import ensemble_utils
-from convlab.dst.setsumbt.loss.ece import ece, jg_ece, l2_acc
+from convlab.dst.setsumbt.modeling import BertSetSUMBT, RobertaSetSUMBT
+from convlab.dst.setsumbt.dataset import unified_format
+from convlab.dst.setsumbt.dataset import ontology as embeddings
+from convlab.dst.setsumbt.utils import get_args, update_args
+from convlab.dst.setsumbt.modeling import evaluation_utils
+from convlab.dst.setsumbt.loss.uncertainty_measures import ece, jg_ece, l2_acc
+from convlab.dst.setsumbt.modeling import training

-# Datasets
-DATASETS = {
-    'multiwoz21': multiwoz21
-}

 # Available model
 MODELS = {
     'bert': (BertSetSUMBT, BertModel, BertConfig, BertTokenizer),
     'roberta': (RobertaSetSUMBT, RobertaModel, RobertaConfig, RobertaTokenizer)
 }
 ...
@@ -54,12 +43,6 @@ def main(args=None, config=None):
     if args is None:
         args, config = get_args(MODELS)

-    # Select Dataset object
-    if args.dataset in DATASETS:
-        Dataset = DATASETS[args.dataset]
-    else:
-        raise NameError('NotImplemented')
-
     if args.model_type in MODELS:
         SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[args.model_type]
     else:
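Reviewer note: the hunk above drops the hard-coded DATASETS registry (the unified data format replaces the multiwoz21 module) but keeps the MODELS registry dispatch. Below is a self-contained sketch of that dispatch pattern; the Dummy* classes are placeholders for the real transformers/SetSUMBT classes, not part of the codebase.

# Sketch of the MODELS registry dispatch; Dummy* classes are stand-ins.
class DummyModel: pass
class DummyEncoder: pass
class DummyConfig: pass
class DummyTokenizer: pass

MODELS = {'bert': (DummyModel, DummyEncoder, DummyConfig, DummyTokenizer)}

model_type = 'bert'
if model_type in MODELS:
    # Unpack the four-tuple: (dst model, candidate encoder, config, tokenizer)
    SetSumbtModel, CandidateEncoderModel, ConfigClass, Tokenizer = MODELS[model_type]
else:
    raise NameError(f"Model type '{model_type}' is not implemented")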
 ...
@@ -67,69 +50,35 @@ def main(args=None, config=None):
     # Set up output directory
     OUTPUT_DIR = args.output_dir
-    if not os.path.exists(OUTPUT_DIR):
-        os.mkdir(OUTPUT_DIR)
     args.output_dir = OUTPUT_DIR
     if not os.path.exists(os.path.join(OUTPUT_DIR, 'predictions')):
         os.mkdir(os.path.join(OUTPUT_DIR, 'predictions'))

+    # Set pretrained model path to the trained checkpoint
     paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
     if 'pytorch_model.bin' in paths and 'config.json' in paths:
         args.model_name_or_path = args.output_dir
         config = ConfigClass.from_pretrained(args.model_name_or_path)
     else:
-        paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
         paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
         if paths:
             paths = paths[0]
             args.model_name_or_path = paths
             config = ConfigClass.from_pretrained(args.model_name_or_path)

-    if args.ensemble_size > 0:
-        paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
-        paths = [os.path.join(args.output_dir, p) for p in paths if 'ensemble_' in p]
-        if paths:
-            args.model_name_or_path = args.output_dir
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
-
     args = update_args(args, config)

-    # Set up data directory
-    DATA_DIR = args.data_dir
-    Dataset.set_datadir(DATA_DIR)
-    embeddings.set_datadir(DATA_DIR)
-
-    if args.shrink_active_domains and args.dataset == 'multiwoz21':
-        Dataset.set_active_domains(['attraction', 'hotel', 'restaurant', 'taxi', 'train'])
-
-    # Download and preprocess
-    Dataset.create_examples(args.max_turn_len, args.predict_intents, args.force_processing)
-
     # Create logger
     global logger
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)

     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

-    if 'stream' not in args.logging_path:
-        fh = logging.FileHandler(args.logging_path)
-        fh.setLevel(logging.INFO)
-        fh.setFormatter(formatter)
-        logger.addHandler(fh)
-    else:
-        ch = logging.StreamHandler()
-        ch.setLevel(level=logging.INFO)
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
+    fh = logging.FileHandler(args.logging_path)
+    fh.setLevel(logging.INFO)
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)

     # Get device
     if torch.cuda.is_available() and args.n_gpu > 0:
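Reviewer note: the checkpoint resolution above prefers a fully exported model in output_dir and falls back to the first 'checkpoint-' subdirectory. A sketch of the same idea as a standalone helper; resolve_model_path is illustrative only, not a SetSUMBT function.

# Sketch: resolve which directory holds the model to evaluate.
import os
from typing import Optional

def resolve_model_path(output_dir: str) -> Optional[str]:
    paths = os.listdir(output_dir) if os.path.exists(output_dir) else []
    if 'pytorch_model.bin' in paths and 'config.json' in paths:
        return output_dir  # a complete exported model lives here
    checkpoints = [os.path.join(output_dir, p) for p in paths if 'checkpoint-' in p]
    return checkpoints[0] if checkpoints else None  # first checkpoint dir, if any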
 ...
@@ -142,18 +91,12 @@ def main(args=None, config=None):
         args.fp16 = False

     # Set up model training/evaluation
-    calibration.set_logger(logger, None)
-    calibration.set_seed(args)
-
-    if args.ensemble_size > 0:
-        ensemble.set_logger(logger, tb_writer)
-        ensemble_utils.set_seed(args)
+    evaluation_utils.set_logger(logger, None)
+    evaluation_utils.set_seed(args)

     # Perform tasks
     if os.path.exists(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions')):
         pred = torch.load(os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
         labels = pred['labels']
         belief_states = pred['belief_states']
         if 'request_labels' in pred:
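Reviewer note: evaluation_utils.set_seed replaces the old calibration.set_seed; its body is not shown in this diff, so the sketch below is only an assumption about what such a helper typically does.

# Assumed behaviour of a set_seed helper (not the actual SetSUMBT code).
import random
import torch

def set_seed(seed: int = 0) -> None:
    random.seed(seed)                      # Python RNG
    torch.manual_seed(seed)                # CPU RNG
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)   # all CUDA RNGs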
 ...
@@ -166,100 +109,41 @@ def main(args=None, config=None):
         else:
             request_belief = None
         del pred
-    elif args.ensemble_size > 0:
-        # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            # Create Tokenizer and embedding model for Data Loaders and ontology
-            encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            embeddings.get_slot_candidate_embeddings('test', args, tokenizer, encoder)
-            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
-
-        exists = False
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
-            test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
-            torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-
-        config, models = ensemble.get_models(args.model_name_or_path, device, ConfigClass, SetSumbtModel)
-        belief_states, labels = ensemble_utils.get_predictions(args, models, device, test_dataloader, test_slots)
-        torch.save({'belief_states': belief_states, 'labels': labels},
-                   os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
     else:
         # Get training batch loaders and ontology embeddings
-        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
-            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
-        else:
-            # Create Tokenizer and embedding model for Data Loaders and ontology
-            encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
-            tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            embeddings.get_slot_candidate_embeddings('test', args, tokenizer, encoder)
-            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
-
-        exists = False
         if os.path.exists(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader')):
             test_dataloader = torch.load(os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))
-            if test_dataloader.batch_size == args.test_batch_size:
-                exists = True
-        if not exists:
+            if test_dataloader.batch_size != args.test_batch_size:
+                test_dataloader = unified_format.change_batch_size(test_dataloader, args.test_batch_size)
+        else:
             tokenizer = Tokenizer(config.candidate_embedding_model_name)
-            test_dataloader = Dataset.get_dataloader('test', args.test_batch_size, tokenizer, args.max_dialogue_len,
-                                                     config.max_turn_len)
+            test_dataloader = unified_format.get_dataloader(args.dataset, 'test', args.test_batch_size, tokenizer,
+                                                            args.max_dialogue_len, config.max_turn_len)
             torch.save(test_dataloader, os.path.join(OUTPUT_DIR, 'dataloaders', 'test.dataloader'))

+        if os.path.exists(os.path.join(OUTPUT_DIR, 'database', 'test.db')):
+            test_slots = torch.load(os.path.join(OUTPUT_DIR, 'database', 'test.db'))
+        else:
+            encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
+            test_slots = embeddings.get_slot_candidate_embeddings(test_dataloader.dataset.ontology, 'test', args,
+                                                                  tokenizer, encoder)

         # Initialise Model
         model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
         model = model.to(device)

-        slots = {slot: test_slots[slot] for slot in test_slots}
-        values = {slot: test_slots[slot][1] for slot in test_slots}
-
-        # Load model ontology
-        model.add_slot_candidates(slots)
-        for slot in model.informable_slot_ids:
-            model.add_value_candidates(slot, values[slot], replace=True)
+        # Get slot and value embeddings
+        training.set_ontology_embeddings(model, test_slots)

-        belief_states = calibration.get_predictions(args, model, device, test_dataloader)
+        belief_states = evaluation_utils.get_predictions(args, model, device, test_dataloader)
         belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = belief_states
         out = {'belief_states': belief_states, 'labels': labels,
                'request_belief': request_belief, 'request_labels': request_labels,
                'domain_belief': domain_belief, 'domain_labels': domain_labels,
                'greeting_belief': greeting_belief, 'greeting_labels': greeting_labels}
         torch.save(out, os.path.join(OUTPUT_DIR, 'predictions', 'test.predictions'))
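Reviewer note: the rewritten else-branch caches the test dataloader and ontology database on disk and rebuilds them only when missing or stale. A sketch of that cache-or-build pattern; load_or_build is an illustrative helper, not a SetSUMBT function.

# Sketch: reuse a torch-serialised object on disk, else build and cache it.
import os
import torch

def load_or_build(path, build_fn):
    if os.path.exists(path):
        return torch.load(path)            # reuse the cached object
    obj = build_fn()                       # otherwise build from scratch
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(obj, path)                  # cache it for the next run
    return obj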
-    # err = [ece(belief_states[slot].reshape(-1, belief_states[slot].size(-1)), labels[slot].reshape(-1), 10)
-    #        for slot in belief_states]
-    # err = max(err)
-    # logger.info('ECE: %f' % err)
-
     # Calculate calibration metrics
     jg = jg_ece(belief_states, labels, 10)
     logger.info('Joint Goal ECE: %f' % jg)
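Reviewer note: jg_ece is imported from convlab.dst.setsumbt.loss.uncertainty_measures and its implementation is not shown in this diff. For intuition only, a generic sketch of expected calibration error over 10 equal-width confidence bins; ece_sketch is not the library's function.

# Sketch: ECE = sum over bins of (bin mass) * |accuracy - mean confidence|.
import torch

def ece_sketch(probs: torch.Tensor, labels: torch.Tensor, n_bins: int = 10) -> float:
    """probs: (N, C) predicted distributions; labels: (N,) gold class ids."""
    conf, pred = probs.max(-1)                 # top-1 confidence and class
    correct = (pred == labels).float()
    edges = torch.linspace(0.0, 1.0, n_bins + 1)
    err = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (conf > lo) & (conf <= hi)      # samples falling in this bin
        if mask.any():
            gap = (correct[mask].mean() - conf[mask].mean()).abs()
            err += (mask.float().mean() * gap).item()  # weight by bin mass
    return err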
 ...
@@ -298,11 +182,11 @@ def main(args=None, config=None):
     logger.info('Slot presence Binary ECE: %f' % jg)

     jg_acc = 0.0
     padding = torch.cat([item.unsqueeze(-1) for _, item in labels.items()], -1).sum(-1) * -1.0
     padding = (padding == len(labels))
     padding = padding.reshape(-1)
     for slot in belief_states:
+        args.accuracy_topn = 1
         topn = args.accuracy_topn
         p_ = belief_states[slot]
         gold = labels[slot]
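Reviewer note: the padding mask above relies on padded turns carrying label -1 for every slot, so the negated sum across slots equals len(labels) exactly at padded positions. A toy illustration; the slot names are invented.

# Toy check of the padding-mask construction used in the hunk above.
import torch

labels = {'hotel-area': torch.tensor([2, 0, -1]),
          'hotel-stars': torch.tensor([4, 1, -1])}
stacked = torch.cat([item.unsqueeze(-1) for item in labels.values()], -1)
padding = (stacked.sum(-1) * -1.0) == len(labels)
print(padding)  # tensor([False, False,  True]) -> only the last turn is padding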
 ...
@@ -317,8 +201,7 @@ def main(args=None, config=None):
             labs = labs[:, :topn]
         else:
             labs = p_.reshape(-1, p_.size(-1)).argmax(dim=-1).unsqueeze(-1)

         acc = [lab in s for lab, s, pad in zip(gold.reshape(-1), labs, padding) if not pad]
         acc = torch.tensor(acc).float()
         jg_acc += acc
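Reviewer note: the membership test above counts a turn as correct when the gold label appears among the model's topn candidates (top-1 reduces to ordinary accuracy). A toy version with made-up probabilities:

# Toy top-n accuracy: is the gold id among the topn highest-scoring values?
import torch

p_ = torch.tensor([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2]])
gold = torch.tensor([1, 2])
topn = 2
labs = p_.topk(topn, dim=-1).indices             # (N, topn) candidate ids
acc = [lab in s for lab, s in zip(gold, labs)]   # gold within the top-n?
print(torch.tensor(acc).float().mean())          # tensor(0.5000)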
 ...
@@ -337,6 +220,34 @@ def main(args=None, config=None):
     l2 = l2_acc(belief_states, labels, remove_belief=True)
     logger.info(f'Binary Model L2 Norm Goal Accuracy: {l2}')

+    padding = torch.cat([item.unsqueeze(-1) for _, item in labels.items()], -1).sum(-1) * -1.0
+    padding = (padding == len(labels))
+    padding = padding.reshape(-1)
+
+    tp, fp, fn, tn, n = 0.0, 0.0, 0.0, 0.0, 0.0
+    for slot in belief_states:
+        p_ = belief_states[slot]
+        gold = labels[slot].reshape(-1)
+        p_ = p_.reshape(-1, p_.size(-1))
+        p_ = p_[~padding].argmax(-1)
+        gold = gold[~padding]
+
+        tp += (p_ == gold)[gold != 0].int().sum().item()
+        fp += (p_ != 0)[gold == 0].int().sum().item()
+        fp += (p_ != gold)[gold != 0].int().sum().item()
+        fp -= (p_ == 0)[gold != 0].int().sum().item()
+        fn += (p_ == 0)[gold != 0].int().sum().item()
+        tn += (p_ == 0)[gold == 0].int().sum().item()
+        n += p_.size(0)
+
+    acc = (tp + tn) / n
+    prec = tp / (tp + fp)
+    rec = tp / (tp + fn)
+    f1 = 2 * (prec * rec) / (prec + rec)
+    logger.info(f"Slot Accuracy: {acc}, Slot F1: {f1}, Slot Precision: {prec}, Slot Recall: {rec}")
+
     for slot in belief_states:
         p = belief_states[slot]
         p = p.reshape(-1, p.size(-1))
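Reviewer note: the new block treats value id 0 as "slot not mentioned", so a true positive is a mentioned slot predicted with exactly the right value. Hypothetical counts to make the logged metrics concrete:

# Worked example with made-up counts (not real evaluation results).
tp, fp, fn, tn = 90.0, 10.0, 20.0, 880.0
n = tp + fp + fn + tn
acc = (tp + tn) / n                   # 0.97
prec = tp / (tp + fp)                 # 0.90
rec = tp / (tp + fn)                  # ~0.818
f1 = 2 * (prec * rec) / (prec + rec)  # ~0.857
print(acc, prec, rec, f1)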
 ...
@@ -347,7 +258,6 @@ def main(args=None, config=None):
         l = labels[slot].reshape(-1)
         l[l > 0] = 1
         labels[slot] = l

     f1 = 0.0
     for slot in belief_states:
         prd = belief_states[slot].argmax(-1)
 ...
convlab/dst/setsumbt/modeling/calibration_utils.py → convlab/dst/setsumbt/modeling/evaluation_utils.py (+5, −3)
 # -*- coding: utf-8 -*-
-# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
+# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
 # Authors: Carel van Niekerk (niekerk@hhu.de)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Discriminative models calibration"""
+"""Evaluation Utilities"""

 import random
 ...
@@ -119,7 +119,9 @@ def get_predictions(args, model, device, dataloader):
     else:
         request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels = [None] * 6

-    return belief_states, labels, request_belief, request_labels, domain_belief, domain_labels, greeting_belief, greeting_labels
+    out = (belief_states, labels, request_belief, request_labels)
+    out += (domain_belief, domain_labels, greeting_belief, greeting_labels)
+    return out


 def normalise(p):
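Reviewer note: the return value keeps the same eight-field arity; building the tuple in two statements only shortens the lines. A sketch of the shape with placeholder values standing in for real model outputs:

# Placeholder values; real callers unpack all eight fields as in evaluate.py.
belief_states, labels = {}, {}
request_belief = request_labels = None
domain_belief = domain_labels = greeting_belief = greeting_labels = None

out = (belief_states, labels, request_belief, request_labels)
out += (domain_belief, domain_labels, greeting_belief, greeting_labels)
assert len(out) == 8  # same arity as the old single-line return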
 ...