Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
E
emoUS-public
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
general
dsml
emoUS-public
Commits
649fbf4a
Commit
649fbf4a
authored
Aug 4, 2022
by
Carel van Niekerk
Browse files
Options
Downloads
Patches
Plain Diff
Add distillation setup tools script
parent
71848388
No related branches found
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
convlab/dst/setsumbt/dataset/unified_format.py
+35
-23
35 additions, 23 deletions
convlab/dst/setsumbt/dataset/unified_format.py
convlab/dst/setsumbt/distillation_setup.py
+250
-0
250 additions, 0 deletions
convlab/dst/setsumbt/distillation_setup.py
with
285 additions
and
23 deletions
convlab/dst/setsumbt/dataset/unified_format.py
+
35
−
23
View file @
649fbf4a
...
...
@@ -215,7 +215,9 @@ class UnifiedFormatDataset(Dataset):
max_turns
:
int
=
12
,
max_seq_len
:
int
=
64
,
train_ratio
:
float
=
1.0
,
seed
:
int
=
0
):
seed
:
int
=
0
,
data
:
dict
=
None
,
ontology
:
dict
=
None
):
"""
Args:
dataset_name (str): Name of the dataset/s to load (multiple to be separated by +)
...
...
@@ -225,7 +227,13 @@ class UnifiedFormatDataset(Dataset):
max_seq_len (int): Maximum number of tokens in a dialogue turn
train_ratio (float): Fraction of training data to use during training
seed (int): Seed governing random order of ids for subsampling
data (dict): Dataset features for loading from dict
ontology (dict): Ontology dict for loading from dict
"""
if
data
is
not
None
:
self
.
ontology
=
ontology
self
.
features
=
data
else
:
if
'
+
'
in
dataset_name
:
dataset_args
=
[{
"
dataset_name
"
:
name
}
for
name
in
dataset_name
.
split
(
'
+
'
)]
else
:
...
...
@@ -303,6 +311,10 @@ class UnifiedFormatDataset(Dataset):
self
.
features
=
{
label
:
self
.
features
[
label
].
to
(
device
)
for
label
in
self
.
features
if
self
.
features
[
label
]
is
not
None
}
    @classmethod
    def from_datadict(cls, data: dict, ontology: dict):
        """
        Construct a dataset instance directly from pre-computed features.

        Args:
            data (dict): Dataset features for loading from dict
            ontology (dict): Ontology dict for loading from dict

        Returns:
            UnifiedFormatDataset: Dataset wrapping the provided features
        """
        # The first three positional constructor arguments are bypassed when
        # loading from a dict (the data/ontology branch of __init__ returns
        # early), so None placeholders are passed — presumably these are the
        # dataset name, set type and tokenizer; confirm against __init__.
        return cls(None, None, None, data=data, ontology=ontology)
def
get_dataloader
(
dataset_name
:
str
,
set_type
:
str
,
...
...
This diff is collapsed.
Click to expand it.
convlab/dst/setsumbt/distillation_setup.py
0 → 100644
+
250
−
0
View file @
649fbf4a
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Get ensemble predictions and build distillation dataloaders
"""
from
argparse
import
ArgumentDefaultsHelpFormatter
,
ArgumentParser
import
os
import
json
import
torch
from
torch.utils.data
import
DataLoader
,
RandomSampler
,
SequentialSampler
from
tqdm
import
tqdm
from
convlab.dst.setsumbt.dataset.unified_format
import
UnifiedFormatDataset
from
convlab.dst.setsumbt.modeling
import
EnsembleSetSUMBT
from
convlab.dst.setsumbt.modeling
import
training
# Run inference on GPU when one is available, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def main():
    """
    Command line entry point.

    Depending on the flags passed, either extracts ensemble predictions
    (``--get_ensemble_distributions``) or builds a distillation dataloader
    from previously saved predictions (``--build_dataloaders``).

    Raises:
        NotImplementedError: If neither action flag is supplied.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--model_type', type=str)
    parser.add_argument('--set_type', type=str)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--reduction', type=str, default='mean')
    parser.add_argument('--get_ensemble_distributions', action='store_true')
    parser.add_argument('--build_dataloaders', action='store_true')
    args = parser.parse_args()

    if args.get_ensemble_distributions:
        get_ensemble_distributions(args)
    elif args.build_dataloaders:
        # Ensemble predictions previously written by get_ensemble_distributions.
        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
        data = torch.load(path)

        # Use a context manager so the ontology file is closed even if
        # json.load raises (original opened/closed the file manually).
        with open(os.path.join(args.model_path, 'database', f'{args.set_type}.json'), 'r') as reader:
            ontology = json.load(reader)

        loader = get_loader(data, ontology, args.set_type, args.batch_size)

        path = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
        torch.save(loader, path)
    else:
        # Fix: the original raised NameError("NotImplemented"); the correct
        # exception for an unsupported action is NotImplementedError.
        raise NotImplementedError("One of --get_ensemble_distributions or --build_dataloaders is required")
def get_loader(data: dict, ontology: dict, set_type: str = 'train', batch_size: int = 3) -> DataLoader:
    """
    Build dataloader from ensemble prediction data

    Args:
        data: Dictionary of ensemble predictions
        ontology: Data ontology
        set_type: Data subset (train/validation/test)
        batch_size: Number of dialogues per batch

    Returns:
        loader: Data loader object
    """
    # Flatten and pad the ensemble predictions, then wrap them in a dataset.
    features = do_label_padding(flatten_data(data))
    dataset = UnifiedFormatDataset.from_datadict(features, ontology)

    # Shuffle only the training subset; keep evaluation order deterministic.
    sampler_cls = RandomSampler if set_type == 'train' else SequentialSampler
    return DataLoader(dataset, sampler=sampler_cls(dataset), batch_size=batch_size)
def do_label_padding(data: dict) -> dict:
    """
    Add padding to the ensemble predictions (used as labels in distillation)

    Args:
        data: Dictionary of ensemble predictions

    Returns:
        data: Padded ensemble predictions
    """
    # A padding turn is one whose attention mask (preferred) or token ids
    # sum to zero along the last dimension.
    source = data['attention_mask'] if 'attention_mask' in data else data['input_ids']
    dialogs, turns = torch.where(source.sum(-1) == 0.0)

    # Overwrite every non-input feature with -1 at the padding positions.
    input_features = ('input_ids', 'attention_mask', 'token_type_ids')
    for key, feats in data.items():
        if key not in input_features:
            feats[dialogs, turns] = -1

    return data
# Mapping from ensemble-prediction feature names to the flattened feature
# names used during training.
map_dict = {'belief_state': 'belief', 'greeting_act_belief': 'goodbye_belief',
            'state_labels': 'labels', 'request_labels': 'request',
            'domain_labels': 'active', 'greeting_labels': 'goodbye'}


def flatten_data(data: dict) -> dict:
    """
    Map data to flattened feature format used in training

    Args:
        data: Ensemble prediction data

    Returns:
        data: Flattened ensemble prediction data
    """
    data_new = dict()
    for label, feats in data.items():
        # Rename known feature groups; unknown labels pass through unchanged.
        label = map_dict.get(label, label)
        # Fix: use isinstance instead of `type(feats) == dict` so dict
        # subclasses (e.g. OrderedDict) are flattened too.
        if isinstance(feats, dict):
            # Nested per-slot/per-domain features become 'group-key' entries.
            for label_, feats_ in feats.items():
                data_new[f'{label}-{label_}'] = feats_
        else:
            data_new[label] = feats

    return data_new
def get_ensemble_distributions(args):
    """
    Load data and get ensemble predictions

    Runs the ensemble model over the saved dataloader for ``args.set_type``
    and stores the collected inputs, labels and predicted distributions to
    ``<model_path>/dataloaders/<set_type>.data`` via torch.save.

    Args:
        args: Runtime arguments (model_path, set_type, reduction are read)
    """
    device = DEVICE

    model = EnsembleSetSUMBT.from_pretrained(args.model_path)
    model = model.to(device)

    print('Model Loaded!')

    # Paths to the serialized dataloader and ontology database for this subset.
    dataloader = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.dataloader')
    database = os.path.join(args.model_path, 'database', f'{args.set_type}.db')

    dataloader = torch.load(dataloader)
    database = torch.load(database)

    # Inject the ontology embeddings into the model before inference.
    training.set_ontology_embeddings(model, database)

    print('Environment set up.')

    # Accumulators for model inputs (one list entry per batch).
    input_ids = []
    token_type_ids = []
    attention_mask = []
    # Ground-truth labels, keyed per slot/domain.
    state_labels = {slot: [] for slot in model.informable_slot_ids}
    request_labels = {slot: [] for slot in model.requestable_slot_ids}
    domain_labels = {domain: [] for domain in model.domain_ids}
    greeting_labels = []
    # is_noisy is only collected when the dataset provides that feature.
    is_noisy = [] if 'is_noisy' in dataloader.dataset.features else None
    # Ensemble prediction accumulators (the distillation targets).
    belief_state = {slot: [] for slot in model.informable_slot_ids}
    request_belief = {slot: [] for slot in model.requestable_slot_ids}
    domain_belief = {domain: [] for domain in model.domain_ids}
    greeting_act_belief = []
    model.eval()
    for batch in tqdm(dataloader, desc='Batch: '):
        ids = batch['input_ids']
        tt_ids = batch['token_type_ids'] if 'token_type_ids' in batch else None
        mask = batch['attention_mask'] if 'attention_mask' in batch else None
        if 'is_noisy' in batch:
            is_noisy.append(batch['is_noisy'])

        # Keep CPU copies of the inputs before moving them to the device.
        input_ids.append(ids)
        token_type_ids.append(tt_ids)
        attention_mask.append(mask)

        ids = ids.to(device)
        tt_ids = tt_ids.to(device) if tt_ids is not None else None
        mask = mask.to(device) if mask is not None else None

        for slot in state_labels:
            state_labels[slot].append(batch['labels-' + slot])
        # Action labels are only present when the model predicts actions.
        if model.config.predict_actions:
            for slot in request_labels:
                request_labels[slot].append(batch['request-' + slot])
            for domain in domain_labels:
                domain_labels[domain].append(batch['active-' + domain])
            greeting_labels.append(batch['goodbye'])

        with torch.no_grad():
            # p: per-slot belief, p_req: request, p_dom: active-domain,
            # p_bye: greeting-act distributions; reduction merges ensemble members.
            p, p_req, p_dom, p_bye, _ = model(ids, mask, tt_ids, reduction=args.reduction)

        # Move predictions back to CPU before accumulating to free GPU memory.
        for slot in belief_state:
            belief_state[slot].append(p[slot].cpu())
        if model.config.predict_actions:
            for slot in request_belief:
                request_belief[slot].append(p_req[slot].cpu())
            for domain in domain_belief:
                domain_belief[domain].append(p_dom[domain].cpu())
            greeting_act_belief.append(p_bye.cpu())

    # Concatenate per-batch lists along the dialogue dimension. The [0] check
    # handles token_type_ids/attention_mask lists filled with None entries.
    input_ids = torch.cat(input_ids, 0) if input_ids[0] is not None else None
    token_type_ids = torch.cat(token_type_ids, 0) if token_type_ids[0] is not None else None
    attention_mask = torch.cat(attention_mask, 0) if attention_mask[0] is not None else None
    # NOTE(review): if the dataset declares 'is_noisy' but no batch contains
    # it, is_noisy stays [] and torch.cat([]) would raise — confirm upstream.
    is_noisy = torch.cat(is_noisy, 0) if is_noisy is not None else None

    state_labels = {slot: torch.cat(l, 0) for slot, l in state_labels.items()}
    if model.config.predict_actions:
        request_labels = {slot: torch.cat(l, 0) for slot, l in request_labels.items()}
        domain_labels = {domain: torch.cat(l, 0) for domain, l in domain_labels.items()}
        greeting_labels = torch.cat(greeting_labels, 0)

    belief_state = {slot: torch.cat(p, 0) for slot, p in belief_state.items()}
    if model.config.predict_actions:
        request_belief = {slot: torch.cat(p, 0) for slot, p in request_belief.items()}
        domain_belief = {domain: torch.cat(p, 0) for domain, p in domain_belief.items()}
        greeting_act_belief = torch.cat(greeting_act_belief, 0)

    # Assemble the output dictionary; optional features are only included
    # when they were present in the input data.
    data = {'input_ids': input_ids}
    if token_type_ids is not None:
        data['token_type_ids'] = token_type_ids
    if attention_mask is not None:
        data['attention_mask'] = attention_mask
    if is_noisy is not None:
        data['is_noisy'] = is_noisy
    data['state_labels'] = state_labels
    data['belief_state'] = belief_state
    if model.config.predict_actions:
        data['request_labels'] = request_labels
        data['domain_labels'] = domain_labels
        data['greeting_labels'] = greeting_labels
        data['request_belief'] = request_belief
        data['domain_belief'] = domain_belief
        data['greeting_act_belief'] = greeting_act_belief

    # Persist for later consumption by get_loader / build_dataloaders.
    file = os.path.join(args.model_path, 'dataloaders', f'{args.set_type}.data')
    torch.save(data, file)
# Script entry point.
if __name__ == "__main__":
    main()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment