From 61b6c665b0a3814fbd43e473140d13daf432289f Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Fri, 8 Jul 2022 11:27:27 +0800
Subject: [PATCH] t5dst multitask

---
 convlab/base_models/t5/dst/merge_data.py      | 21 +++++
 .../base_models/t5/dst/merge_predict_res.py   | 19 ++--
 convlab/base_models/t5/dst/run_dst.sh         |  2 +-
 .../base_models/t5/dst/run_dst_multitask.sh   | 94 +++++++++++++++++++
 4 files changed, 129 insertions(+), 7 deletions(-)
 create mode 100644 convlab/base_models/t5/dst/merge_data.py
 create mode 100644 convlab/base_models/t5/dst/run_dst_multitask.sh

diff --git a/convlab/base_models/t5/dst/merge_data.py b/convlab/base_models/t5/dst/merge_data.py
new file mode 100644
index 00000000..7b76cdcd
--- /dev/null
+++ b/convlab/base_models/t5/dst/merge_data.py
@@ -0,0 +1,21 @@
+import json
+import os
+import sys
+
+if __name__ == '__main__':
+    merged_data = {'train': [], 'validation': [], 'test': []}
+    print(sys.argv)
+    for dataset_name in sys.argv[1:]:
+        data_dir = os.path.join('data/dst', dataset_name, 'user/context_100')
+        for data_split in merged_data:
+            with open(os.path.join(data_dir, f'{data_split}.json'), 'r') as f:
+                for line in f:
+                    item = json.loads(line)
+                    item['context'] = f"{dataset_name}: {item['context']}"
+                    merged_data[data_split].append(item)
+    for data_split in merged_data:
+        data_dir = os.path.join('data/dst', '+'.join(sys.argv[1:]), 'user/context_100')
+        os.makedirs(data_dir, exist_ok=True)
+        with open(os.path.join(data_dir, f'{data_split}.json'), 'w') as f:
+            for item in merged_data[data_split]:
+                f.write(json.dumps(item)+'\n')
diff --git a/convlab/base_models/t5/dst/merge_predict_res.py b/convlab/base_models/t5/dst/merge_predict_res.py
index 6d21d07c..f25279a8 100755
--- a/convlab/base_models/t5/dst/merge_predict_res.py
+++ b/convlab/base_models/t5/dst/merge_predict_res.py
@@ -4,10 +4,8 @@ from convlab.util import load_dataset, load_dst_data
 from convlab.base_models.t5.dst.serialization import deserialize_dialogue_state
 
 
-def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+def merge(dataset_names, speaker, save_dir, context_window_size, predict_result):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name, args.dial_ids_order)
-    data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
 
     if save_dir is None:
         save_dir = os.path.dirname(predict_result)
@@ -15,10 +13,19 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
         os.makedirs(save_dir, exist_ok=True)
     predict_result = [deserialize_dialogue_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
 
-    for sample, prediction in zip(data, predict_result):
-        sample['predictions'] = {'state': prediction}
+    merged = []
+    i = 0
+    for dataset_name in dataset_names.split('+'):
+        print(dataset_name)
+        dataset = load_dataset(dataset_name, args.dial_ids_order)
+        data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+
+        for sample in data:
+            sample['predictions'] = {'state': predict_result[i]}
+            i += 1
+            merged.append(sample)
 
-    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
 
 
 if __name__ == '__main__':
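Note on the two Python changes above: merge_data.py concatenates the per-dataset DST splits into a single data/dst/<name1>+<name2>+.../user/context_100 directory, prefixing each context with its dataset name so the multitask model can tell the sources apart, and merge_predict_res.py now walks the test sets in the same '+'-delimited order so the flat prediction file realigns with the right samples. A minimal sketch of what one merged training line looks like (only the field names context/state_seq are taken from the scripts; the values below are invented placeholders):

    import json

    # One line from a hypothetical data/dst/multiwoz21/user/context_100/train.json;
    # the field names match source_column/target_column in the shell scripts, but
    # both values are made up for illustration.
    item = {
        "context": "user: I need a cheap hotel in the north.",
        "state_seq": "[hotel][price range][cheap][area][north]",
    }

    # merge_data.py rewrites the context so the dataset name leads the input:
    item["context"] = f"multiwoz21: {item['context']}"
    print(json.dumps(item))
    # -> {"context": "multiwoz21: user: I need a cheap hotel in the north.", ...}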
diff --git a/convlab/base_models/t5/dst/run_dst.sh b/convlab/base_models/t5/dst/run_dst.sh
index 2dfc622d..0704ebf9 100644
--- a/convlab/base_models/t5/dst/run_dst.sh
+++ b/convlab/base_models/t5/dst/run_dst.sh
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
diff --git a/convlab/base_models/t5/dst/run_dst_multitask.sh b/convlab/base_models/t5/dst/run_dst_multitask.sh
new file mode 100644
index 00000000..0f3b60a6
--- /dev/null
+++ b/convlab/base_models/t5/dst/run_dst_multitask.sh
@@ -0,0 +1,94 @@
+n_gpus=1
+task_name="dst"
+dataset_name="sgd+tm1+tm2+tm3+multiwoz21"
+speaker="user"
+context_window_size=100
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="dst_metric.py"
+metric_for_best_model="accuracy"
+source_column="context"
+target_column="state_seq"
+truncation_side="left"
+max_source_length=1024
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=64
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=10
+
+names=$(echo ${dataset_name} | tr "+" "\n")
+rm -r ${data_dir}
+mkdir -p ${data_dir}
+for name in ${names};
+do
+    echo "preprocessing ${name}"
+    # python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
+done
+
+python merge_data.py $(echo ${dataset_name} | tr "+" " ")
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
--
GitLab
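Note on the new pipeline script: run_dst_multitask.sh rebuilds the merged data directory, fine-tunes a single t5-small on the union of sgd, tm1, tm2, tm3, and multiwoz21, predicts on the merged test file, and hands the flat generated_predictions.json to merge_predict_res.py, which pairs predictions with samples purely by running index. That pairing silently breaks if the merged test file and the prediction file ever disagree in length or order, so a cheap sanity check along these lines (paths derived from the script's variables; the check itself is not part of the patch) is worth running first:

    import os

    # Hypothetical sanity check, not part of the patch: merge_predict_res.py
    # realigns predictions with test samples by index, so both jsonlines files
    # must contain exactly the same number of lines.
    base = 'sgd+tm1+tm2+tm3+multiwoz21/user/context_100'
    test_file = os.path.join('data/dst', base, 'test.json')
    pred_file = os.path.join('output/dst', base, 'generated_predictions.json')

    with open(test_file) as f:
        n_samples = sum(1 for _ in f)
    with open(pred_file) as f:
        n_preds = sum(1 for _ in f)

    assert n_preds == n_samples, f'{n_preds} predictions vs {n_samples} test samples'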