Commit 9cf97f07 authored by zqwerty
add t5dst

parent fce56fd0
Showing with 290 additions and 15 deletions
@@ -4,6 +4,7 @@ from tqdm import tqdm
 import re
 from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
 from convlab2.base_models.t5.nlu.serialization import serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq
+from convlab2.base_models.t5.dst.serialization import serialize_dialogue_state, deserialize_dialogue_state, equal_state_seq

 def create_rg_data(dataset, data_dir, args):
     data_by_split = load_rg_data(dataset, speaker=args.speaker)
@@ -44,6 +45,28 @@ def create_nlu_data(dataset, data_dir, args):
         with open(file_name, "w", encoding='utf-8') as f:
             f.writelines(data)

+def create_dst_data(dataset, data_dir, args):
+    data_by_split = load_dst_data(dataset, speaker=args.speaker, use_context=args.context_window_size>0, context_window_size=args.context_window_size)
+    data_dir = os.path.join(data_dir, args.speaker, f'context_{args.context_window_size}')
+    os.makedirs(data_dir, exist_ok=True)
+
+    data_splits = data_by_split.keys()
+    for data_split in data_splits:
+        data = []
+        for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
+            response = f"{sample['speaker']}: {sample['utterance']}"
+            if args.context_window_size>0:
+                context = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[response])
+            else:
+                context = response
+            state_seq = serialize_dialogue_state(sample['state'])
+            assert equal_state_seq(sample['state'], state_seq), f"state: {sample['state']}\nstate_seq: {state_seq}\ndeserialized: {deserialize_dialogue_state(state_seq)}"
+            data.append(json.dumps({'context': context, 'state_seq': state_seq}, ensure_ascii=False)+'\n')
+
+        file_name = os.path.join(data_dir, f"{data_split}.json")
+        with open(file_name, "w", encoding='utf-8') as f:
+            f.writelines(data)
+
 def create_goal2dialogue_data(dataset, data_dir, args):
     data_by_split = dataset
     os.makedirs(data_dir, exist_ok=True)
@@ -64,7 +87,7 @@ def create_goal2dialogue_data(dataset, data_dir, args):
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser(description="create data for seq2seq training")
-    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'goal2dialogue'], help='names of tasks')
+    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'goal2dialogue'], help='names of tasks')
     parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets')
     parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s)')
     parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
......
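For reference, create_dst_data above writes one JSON object per line. A hypothetical MultiWOZ-style sample (utterance and state invented for illustration; the 'context'/'state_seq' keys and the [domain][slot][value] serialization follow the code) would look like:

{"context": "user: I need a cheap restaurant in the centre.", "state_seq": "[restaurant][price range][cheap];[restaurant][area][centre]"}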
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DST Metric"""
import datasets
from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state
# TODO: Add BibTeX citation
_CITATION = """\
"""
_DESCRIPTION = """\
Metric to evaluate text-to-text models on the dialog state tracking task.
"""
_KWARGS_DESCRIPTION = """
Calculates sequence exact match, joint goal accuracy and slot f1
Args:
    predictions: list of predictions to score. Each prediction
        should be a string.
    references: list of references, one for each prediction. Each
        reference should be a string.
Returns:
    seq_em: sequence exact match
    accuracy: joint goal accuracy (dialog state accuracy)
    slot_f1: slot f1 (slot_precision and slot_recall are also returned)
Examples:

    >>> dst_metric = datasets.load_metric("dst_metric.py")
    >>> predictions = ["[restaurant][price range][moderate]", "[restaurant][price range][moderate];[restaurant][food][catalan];[restaurant][area][centre]"]
    >>> references = ["[restaurant][price range][moderate]", "[restaurant][price range][moderate];[restaurant][food][catalan];[attraction][area][centre]"]
    >>> results = dst_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'seq_em': 0.5, 'accuracy': 0.5, 'slot_f1': 0.75, 'slot_precision': 0.75, 'slot_recall': 0.75}
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DSTMetrics(datasets.Metric):
    """Metric to evaluate text-to-text models on the dialog state tracking task."""

    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features({
                'predictions': datasets.Value('string'),
                'references': datasets.Value('string'),
            })
        )

    def _compute(self, predictions, references):
        """Returns the scores: sequence exact match, joint goal accuracy and slot f1"""
        seq_em = []
        acc = []
        f1_metrics = {'TP': 0, 'FP': 0, 'FN': 0}

        for prediction, reference in zip(predictions, references):
            seq_em.append(prediction.strip() == reference.strip())
            pred_state = deserialize_dialogue_state(prediction)
            gold_state = deserialize_dialogue_state(reference)
            predicts = sorted(list({(domain, slot, value) for domain in pred_state for slot, value in pred_state[domain].items() if len(value) > 0}))
            labels = sorted(list({(domain, slot, value) for domain in gold_state for slot, value in gold_state[domain].items() if len(value) > 0}))
            flag = True
            for ele in predicts:
                if ele in labels:
                    f1_metrics['TP'] += 1
                else:
                    f1_metrics['FP'] += 1
            for ele in labels:
                if ele not in predicts:
                    f1_metrics['FN'] += 1
            flag &= (predicts == labels)
            acc.append(flag)

        TP = f1_metrics.pop('TP')
        FP = f1_metrics.pop('FP')
        FN = f1_metrics.pop('FN')
        precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
        recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
        f1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
        f1_metrics['slot_f1'] = f1
        f1_metrics['slot_precision'] = precision
        f1_metrics['slot_recall'] = recall

        return {
            "seq_em": sum(seq_em) / len(seq_em),
            "accuracy": sum(acc) / len(acc),
            **f1_metrics
        }
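To check the docstring example by hand: the first pair matches exactly, so it counts toward both seq_em and joint accuracy; the second prediction gets the price range and food triples right but outputs [restaurant][area][centre] where the reference has [attraction][area][centre], i.e. one false positive and one false negative. Over both pairs TP = 3, FP = 1, FN = 1, giving precision = recall = 3/4 = 0.75 and F1 = 2 * 0.75 * 0.75 / 1.5 = 0.75, which matches the printed result.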
import json
import os
from convlab2.util import load_dataset, load_dst_data
from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state


def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
    assert os.path.exists(predict_result)
    dataset = load_dataset(dataset_name)
    data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']

    if save_dir is None:
        save_dir = os.path.dirname(predict_result)
    else:
        os.makedirs(save_dir, exist_ok=True)
    predict_result = [deserialize_dialogue_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]

    for sample, prediction in zip(data, predict_result):
        sample['predictions'] = {'state': prediction}

    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="merge predict results with original data for unified DST evaluation")
    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
    parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: the same directory as predict_result')
    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
    args = parser.parse_args()
    print(args)
    merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
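merge() assumes predict_result is a JSON-lines file in which each line is an object carrying the generated state sequence under a 'predictions' key (this is what the json.loads(x)['predictions'] parsing above expects). An illustrative line, not actual model output:

{"predictions": "[restaurant][price range][cheap];[restaurant][area][centre]"}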
n_gpus=4
task_name="dst"
dataset_name="multiwoz21"
speaker="user"
context_window_size=100
data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="dst_metric.py"
metric_for_best_model="accuracy"
source_column="context"
target_column="state_seq"
truncation_side="left"
max_source_length=512
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=10
python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --speaker ${speaker} --context_window_size ${context_window_size}
python -m torch.distributed.launch --master_port 29501 \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--do_predict \
--save_strategy epoch \
--evaluation_strategy epoch \
--load_best_model_at_end \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--debug underflow_overflow \
--adafactor \
--gradient_checkpointing
# python -m torch.distributed.launch \
# --nproc_per_node ${n_gpus} ../run_seq2seq.py \
# --task_name ${task_name} \
# --test_file ${test_file} \
# --source_column ${source_column} \
# --target_column ${target_column} \
# --max_source_length ${max_source_length} \
# --max_target_length ${max_target_length} \
# --truncation_side ${truncation_side} \
# --model_name_or_path ${output_dir} \
# --do_predict \
# --predict_with_generate \
# --metric_name_or_path ${metric_name_or_path} \
# --cache_dir ${cache_dir} \
# --output_dir ${output_dir} \
# --logging_dir ${logging_dir} \
# --overwrite_output_dir \
# --preprocessing_num_workers 4 \
# --per_device_eval_batch_size ${per_device_eval_batch_size} \
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
def serialize_dialogue_state(state):
    # flatten the state dict into ";"-joined [domain][slot][value] triples,
    # skipping slots whose value is empty
    state_seqs = []
    for domain in state:
        for slot, value in state[domain].items():
            if len(value) > 0:
                state_seqs.append(f'[{domain}][{slot}][{value}]')
    return ';'.join(state_seqs)


def deserialize_dialogue_state(state_seq):
    state = {}
    if len(state_seq) == 0:
        return state
    # split on the separator between consecutive [domain][slot][value] triples
    state_seqs = state_seq.split('];[')
    for i, state_seq in enumerate(state_seqs):
        if len(state_seq) == 0:
            continue
        # strip the outer brackets left on the first and last triples
        if i == 0:
            if state_seq[0] == '[':
                state_seq = state_seq[1:]
        if i == len(state_seqs) - 1:
            if state_seq[-1] == ']':
                state_seq = state_seq[:-1]
        s = state_seq.split('][')
        if len(s) != 3:
            # skip malformed triples
            continue
        domain, slot, value = s
        state.setdefault(domain, {})
        state[domain][slot] = value
    return state


def equal_state_seq(state, state_seq):
    # compare the non-empty (domain, slot, value) triples of the original state
    # with those recovered from the serialized sequence
    predict_state = deserialize_dialogue_state(state_seq)
    svs = sorted([(domain, slot, value) for domain in state for slot, value in state[domain].items() if len(value) > 0])
    predict_svs = sorted([(domain, slot, value) for domain in predict_state for slot, value in predict_state[domain].items() if len(value) > 0])
    if svs != predict_svs:
        return False
    return True
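A minimal round-trip sketch of the serialization helpers above (the state dict shape mirrors the unified-dataset format; the values are illustrative):

state = {'restaurant': {'price range': 'cheap', 'area': 'centre', 'name': ''}}
seq = serialize_dialogue_state(state)
print(seq)  # [restaurant][price range][cheap];[restaurant][area][centre]
assert equal_state_seq(state, seq)  # empty-valued slots ('name') are skipped on both sides
print(deserialize_dialogue_state(seq))  # {'restaurant': {'price range': 'cheap', 'area': 'centre'}}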
@@ -83,13 +83,11 @@ class NLUMetrics(datasets.Metric):
             flag = True
             for da_type in ['binary', 'categorical', 'non-categorical']:
                 if da_type == 'binary':
-                    predicts = [(x['intent'], x['domain'], x['slot']) for x in pred_da[da_type]]
-                    labels = [(x['intent'], x['domain'], x['slot']) for x in gold_da[da_type]]
+                    predicts = sorted(list({(x['intent'], x['domain'], x['slot']) for x in pred_da[da_type]}))
+                    labels = sorted(list({(x['intent'], x['domain'], x['slot']) for x in gold_da[da_type]}))
                 else:
-                    predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in pred_da[da_type]]
-                    labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in gold_da[da_type]]
-                predicts = sorted(list(set(predicts)))
-                labels = sorted(list(set(labels)))
+                    predicts = sorted(list({(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in pred_da[da_type]}))
+                    labels = sorted(list({(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in gold_da[da_type]}))
                 for ele in predicts:
                     if ele in labels:
                         f1_metrics['overall']['TP'] += 1
......
@@ -39,7 +39,6 @@ python -m torch.distributed.launch \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
     --per_device_train_batch_size ${per_device_train_batch_size} \
-    --per_device_eval_batch_size ${per_device_eval_batch_size} \
     --gradient_accumulation_steps ${gradient_accumulation_steps} \
     --learning_rate ${lr} \
     --num_train_epochs ${num_train_epochs} \
......
@@ -39,7 +39,6 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
-   --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
......
@@ -39,7 +39,6 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
-   --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
......
@@ -39,7 +39,6 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
-   --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
......
@@ -39,7 +39,6 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
-   --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
......
@@ -39,7 +39,6 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
-   --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
......
@@ -39,7 +39,6 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
-   --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
......
......@@ -39,7 +39,6 @@ python -m torch.distributed.launch \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
......