From 6daadd33e8c37dd1d12c3c895a44a76d4048a9f5 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Sun, 10 Apr 2022 13:15:37 +0800
Subject: [PATCH] fix import bug in dst/merge_predict_res.py

---
 convlab2/base_models/t5/dst/dst_metric.py     |  4 +-
 .../base_models/t5/dst/merge_predict_res.py   |  4 +-
 convlab2/base_models/t5/dst/run_multiwoz21.sh | 43 +++++++++----------
 convlab2/base_models/t5/rg/run_rg.sh          | 28 ++++++------
 4 files changed, 39 insertions(+), 40 deletions(-)

diff --git a/convlab2/base_models/t5/dst/dst_metric.py b/convlab2/base_models/t5/dst/dst_metric.py
index 8a4f73b0..aedef34d 100644
--- a/convlab2/base_models/t5/dst/dst_metric.py
+++ b/convlab2/base_models/t5/dst/dst_metric.py
@@ -75,8 +75,8 @@ class DSTMetrics(datasets.Metric):
             pred_state = deserialize_dialogue_state(prediction)
             gold_state = deserialize_dialogue_state(reference)
 
-            predicts = sorted(list({(domain, slot, value) for domain in pred_state for slot, value in pred_state[domain].items() if len(value)>0}))
-            labels = sorted(list({(domain, slot, value) for domain in gold_state for slot, value in gold_state[domain].items() if len(value)>0}))
+            predicts = sorted(list({(domain, slot, ''.join(value.split()).lower()) for domain in pred_state for slot, value in pred_state[domain].items() if len(value)>0}))
+            labels = sorted(list({(domain, slot, ''.join(value.split()).lower()) for domain in gold_state for slot, value in gold_state[domain].items() if len(value)>0}))
 
             flag = True
             for ele in predicts:
diff --git a/convlab2/base_models/t5/dst/merge_predict_res.py b/convlab2/base_models/t5/dst/merge_predict_res.py
index ebdada8a..0a80ee80 100755
--- a/convlab2/base_models/t5/dst/merge_predict_res.py
+++ b/convlab2/base_models/t5/dst/merge_predict_res.py
@@ -1,7 +1,7 @@
 import json
 import os
 from convlab2.util import load_dataset, load_dst_data
-from convlab2.base_models.t5.dst.serialization import deserialize_state
+from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state
 
 
 def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
@@ -13,7 +13,7 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
         save_dir = os.path.dirname(predict_result)
     else:
         os.makedirs(save_dir, exist_ok=True)
-    predict_result = [deserialize_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
+    predict_result = [deserialize_dialogue_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
 
     for sample, prediction in zip(data, predict_result):
         sample['predictions'] = {'state': prediction}
diff --git a/convlab2/base_models/t5/dst/run_multiwoz21.sh b/convlab2/base_models/t5/dst/run_multiwoz21.sh
index e031be48..e7573e95 100644
--- a/convlab2/base_models/t5/dst/run_multiwoz21.sh
+++ b/convlab2/base_models/t5/dst/run_multiwoz21.sh
@@ -26,7 +26,7 @@ num_train_epochs=10
 
 python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --speaker ${speaker} --context_window_size ${context_window_size}
 
-python -m torch.distributed.launch --master_port 29501 \
+python -m torch.distributed.launch \
     --nproc_per_node ${n_gpus} ../run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
@@ -43,8 +43,7 @@ python -m torch.distributed.launch --master_port 29501 \
     --do_predict \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --load_best_model_at_end \
-    --predict_with_generate \
+    --prediction_loss_only \
    --metric_name_or_path ${metric_name_or_path} \
    --cache_dir ${cache_dir} \
    --output_dir ${output_dir} \
@@ -60,24 +59,24 @@ python -m torch.distributed.launch --master_port 29501 \
     --adafactor \
     --gradient_checkpointing
 
-# python -m torch.distributed.launch \
-#     --nproc_per_node ${n_gpus} ../run_seq2seq.py \
-#     --task_name ${task_name} \
-#     --test_file ${test_file} \
-#     --source_column ${source_column} \
-#     --target_column ${target_column} \
-#     --max_source_length ${max_source_length} \
-#     --max_target_length ${max_target_length} \
-#     --truncation_side ${truncation_side} \
-#     --model_name_or_path ${output_dir} \
-#     --do_predict \
-#     --predict_with_generate \
-#     --metric_name_or_path ${metric_name_or_path} \
-#     --cache_dir ${cache_dir} \
-#     --output_dir ${output_dir} \
-#     --logging_dir ${logging_dir} \
-#     --overwrite_output_dir \
-#     --preprocessing_num_workers 4 \
-#     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+python -m torch.distributed.launch \
+    --nproc_per_node ${n_gpus} ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
 
 python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
diff --git a/convlab2/base_models/t5/rg/run_rg.sh b/convlab2/base_models/t5/rg/run_rg.sh
index 55accadf..6fcffca2 100644
--- a/convlab2/base_models/t5/rg/run_rg.sh
+++ b/convlab2/base_models/t5/rg/run_rg.sh
@@ -16,24 +16,24 @@ truncation_side="left"
 max_source_length=512
 max_target_length=128
 model_name_or_path="t5-small"
-per_device_train_batch_size=32
+per_device_train_batch_size=128
 per_device_eval_batch_size=128
 gradient_accumulation_steps=4
 lr=1e-3
 num_train_epochs=5
 
-# names=$(echo ${dataset_name} | tr "+" "\n")
-# mkdir -p ${data_dir}
-# for name in ${names};
-# do
-#     echo "preprocessing ${name}"
-#     python ../create_data.py --tasks ${task_name} --datasets ${name} --speaker ${speaker}
-#     if [ "${name}" != "${dataset_name}" ]; then
-#         cat "data/${task_name}/${name}/${speaker}/train.json" >> ${train_file}
-#         cat "data/${task_name}/${name}/${speaker}/validation.json" >> ${validation_file}
-#         cat "data/${task_name}/${name}/${speaker}/test.json" >> ${test_file}
-#     fi
-# done
+names=$(echo ${dataset_name} | tr "+" "\n")
+mkdir -p ${data_dir}
+for name in ${names};
+do
+    echo "preprocessing ${name}"
+    python ../create_data.py --tasks ${task_name} --datasets ${name} --speaker ${speaker}
+    if [ "${name}" != "${dataset_name}" ]; then
+        cat "data/${task_name}/${name}/${speaker}/train.json" >> ${train_file}
+        cat "data/${task_name}/${name}/${speaker}/validation.json" >> ${validation_file}
+        cat "data/${task_name}/${name}/${speaker}/test.json" >> ${test_file}
+    fi
+done
 
 python -m torch.distributed.launch \
     --nproc_per_node ${n_gpus} ../run_seq2seq.py \
@@ -53,7 +53,7 @@ python -m torch.distributed.launch \
     --save_strategy epoch \
     --evaluation_strategy epoch \
     --load_best_model_at_end \
-    --predict_with_generate \
+    --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
     --logging_dir ${logging_dir} \
-- 
GitLab
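
A minimal sketch (not part of the patch) of the effect of the dst_metric.py change: slot values are whitespace-stripped and lowercased before the predicted and gold state triples are compared, so surface variants of the same value no longer count as errors. The helper name normalize_triples and the example states below are illustrative only.

    def normalize_triples(state):
        # Flatten a {domain: {slot: value}} dict into sorted (domain, slot, value)
        # triples, normalizing each value the way the patched metric does.
        return sorted({
            (domain, slot, ''.join(value.split()).lower())
            for domain in state
            for slot, value in state[domain].items()
            if len(value) > 0
        })

    pred_state = {'hotel': {'type': 'Guest House'}}
    gold_state = {'hotel': {'type': 'guest house'}}

    # Both normalize to [('hotel', 'type', 'guesthouse')], so this turn now
    # counts as a joint-accuracy match.
    assert normalize_triples(pred_state) == normalize_triples(gold_state)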
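
And a hedged sketch of how the corrected import is exercised when merging predictions. The JSON-lines layout of generated_predictions.json (one object per line with a "predictions" field) is taken from the patched merge() above; the file path here is illustrative.

    import json
    from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state

    # Each line holds one JSON object whose "predictions" field is a serialized
    # dialogue state; deserialize_dialogue_state (the corrected name, replacing
    # the stale deserialize_state) turns it back into a {domain: {slot: value}} dict.
    with open('output/multiwoz21/generated_predictions.json') as f:
        states = [deserialize_dialogue_state(json.loads(line)['predictions'].strip())
                  for line in f]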