diff --git a/README.md b/README.md index de367a3a7d6d3ab2187d4cbb4a1a10550936370a..6b26b6a25fc8c1c94801a22490e2a5b7f9833452 100755 --- a/README.md +++ b/README.md @@ -16,11 +16,12 @@ ## Updates +- **2023.2.26**: Update ConvLab on PyPI to 3.0.1 to reflect bug fixes. - **2022.11.30**: ConvLab-3 release. ## Installation -You can install ConvLab-3 in one of the following ways according to your need. Higher versions of `torch` and `transformers` may also work. +You can install ConvLab-3 in one of the following ways according to your need. We use `torch>=1.10.1,<=1.13` and `transformers>=4.17.0,<=4.24.0`. Higher versions of `torch` and `transformers` may also work. ### Git clone and pip install in development mode (Recommend) diff --git a/convlab/base_models/t5/dst/merge_predict_res.py b/convlab/base_models/t5/dst/merge_predict_res.py index f25279a87a8404faab03523a072782c1a08b738c..9b519229b5b0a941a6d4172bd1473c6c8077658a 100755 --- a/convlab/base_models/t5/dst/merge_predict_res.py +++ b/convlab/base_models/t5/dst/merge_predict_res.py @@ -4,7 +4,7 @@ from convlab.util import load_dataset, load_dst_data from convlab.base_models.t5.dst.serialization import deserialize_dialogue_state -def merge(dataset_names, speaker, save_dir, context_window_size, predict_result): +def merge(dataset_names, speaker, save_dir, context_window_size, predict_result, dial_ids_order): assert os.path.exists(predict_result) if save_dir is None: @@ -17,14 +17,18 @@ def merge(dataset_names, speaker, save_dir, context_window_size, predict_result) i = 0 for dataset_name in dataset_names.split('+'): print(dataset_name) - dataset = load_dataset(dataset_name, args.dial_ids_order) + single = [] + dataset = load_dataset(dataset_name, dial_ids_order) data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test'] for sample in data: sample['predictions'] = {'state': predict_result[i]} i += 1 + single.append(sample) merged.append(sample) + json.dump(single, open(os.path.join(save_dir, f'{dataset_name}_predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) @@ -35,8 +39,8 @@ if __name__ == '__main__': parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances') parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result') parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered') - parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json') + parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file test_generated_predictions.json') parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments') args = parser.parse_args() print(args) - merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result) + merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result, args.dial_ids_order) diff --git a/convlab/base_models/t5/dst/run_dst.sh b/convlab/base_models/t5/dst/run_dst.sh index 05975400bd1ca901e1058dc80587e2cce0b0f1bb..0e2b94969049bed448cba89a76d866c58cdcc775 100644 --- a/convlab/base_models/t5/dst/run_dst.sh +++ b/convlab/base_models/t5/dst/run_dst.sh @@ -1,6 +1,6 @@ n_gpus=1 task_name="dst" -dataset_name=$1 +dataset_name=crosswoz speaker="user" context_window_size=100 data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" @@ -17,10 +17,10 @@ target_column="state_seq" truncation_side="left" max_source_length=1024 max_target_length=512 -model_name_or_path="t5-small" -per_device_train_batch_size=64 -per_device_eval_batch_size=64 -gradient_accumulation_steps=2 +model_name_or_path="/data/zhuqi/pre-trained-models/mt5-small" +per_device_train_batch_size=16 +per_device_eval_batch_size=16 +gradient_accumulation_steps=4 lr=1e-3 num_train_epochs=10 @@ -80,6 +80,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/dst/run_dst_fewshot.sh b/convlab/base_models/t5/dst/run_dst_fewshot.sh index 4acd605706752c67d1f1df3b5fa04df13d2e46ad..2e2e999811149acc7926c7813cdad769ca1d8fbc 100644 --- a/convlab/base_models/t5/dst/run_dst_fewshot.sh +++ b/convlab/base_models/t5/dst/run_dst_fewshot.sh @@ -82,6 +82,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/dst/run_dst_multitask.sh b/convlab/base_models/t5/dst/run_dst_multitask.sh index aefb1d5200db292d0e68e7d39498dfbd182d1fa0..7846e049aecf9302063517533e2810aaea9d339c 100644 --- a/convlab/base_models/t5/dst/run_dst_multitask.sh +++ b/convlab/base_models/t5/dst/run_dst_multitask.sh @@ -30,7 +30,7 @@ mkdir -p ${data_dir} for name in ${names}; do echo "preprocessing ${name}" - # python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size} + python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size} done python merge_data.py $(echo ${dataset_name} | tr "+" " ") @@ -89,6 +89,10 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json +for name in ${names}; +do + echo "evaluating ${name}" + python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/${name}_predictions.json +done \ No newline at end of file diff --git a/convlab/base_models/t5/nlg/merge_predict_res.py b/convlab/base_models/t5/nlg/merge_predict_res.py index 7d2995d84737378958a54765f3efc5f996f112e3..a4d1b34155fda3826b6aaa100ddf7a8f5c37c49f 100755 --- a/convlab/base_models/t5/nlg/merge_predict_res.py +++ b/convlab/base_models/t5/nlg/merge_predict_res.py @@ -3,7 +3,7 @@ import os from convlab.util import load_dataset, load_nlg_data -def merge(dataset_names, speaker, save_dir, context_window_size, predict_result): +def merge(dataset_names, speaker, save_dir, context_window_size, predict_result, dial_ids_order): assert os.path.exists(predict_result) if save_dir is None: @@ -16,7 +16,8 @@ def merge(dataset_names, speaker, save_dir, context_window_size, predict_result) i = 0 for dataset_name in dataset_names.split('+'): print(dataset_name) - dataset = load_dataset(dataset_name, args.dial_ids_order) + single = [] + dataset = load_dataset(dataset_name, dial_ids_order) data = load_nlg_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test'] for sample in data: @@ -24,8 +25,11 @@ def merge(dataset_names, speaker, save_dir, context_window_size, predict_result) continue sample['predictions'] = {'utterance': predict_result[i]} i += 1 + single.append(sample) merged.append(sample) + json.dump(single, open(os.path.join(save_dir, f'{dataset_name}_predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) @@ -36,8 +40,8 @@ if __name__ == '__main__': parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances') parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result') parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered') - parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json') + parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file test_generated_predictions.json') parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments') args = parser.parse_args() print(args) - merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result) + merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result, args.dial_ids_order) diff --git a/convlab/base_models/t5/nlg/run_nlg.sh b/convlab/base_models/t5/nlg/run_nlg.sh index 0b5fa390dcaf98b098abc17f18026994ee54702c..718dca4a2e344fc903e20f48e75cccd46b7821ce 100644 --- a/convlab/base_models/t5/nlg/run_nlg.sh +++ b/convlab/base_models/t5/nlg/run_nlg.sh @@ -1,8 +1,8 @@ n_gpus=1 task_name="nlg" -dataset_name=$1 -speaker="system" -context_window_size=$2 +dataset_name=crosswoz +speaker="all" +context_window_size=0 data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" cache_dir="../cache" @@ -17,10 +17,10 @@ target_column="response" truncation_side="left" max_source_length=512 max_target_length=512 -model_name_or_path="t5-small" -per_device_train_batch_size=128 -per_device_eval_batch_size=64 -gradient_accumulation_steps=4 +model_name_or_path="/data/zhuqi/pre-trained-models/mt5-small" +per_device_train_batch_size=32 +per_device_eval_batch_size=16 +gradient_accumulation_steps=8 lr=1e-3 num_train_epochs=10 @@ -80,6 +80,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name} diff --git a/convlab/base_models/t5/nlg/run_nlg_fewshot.sh b/convlab/base_models/t5/nlg/run_nlg_fewshot.sh index 61e50cdaa094b301660d38f74fcf8420424a7d3f..17f110a0da83dd2cd4ef74bff0b9f8f43a8637a1 100644 --- a/convlab/base_models/t5/nlg/run_nlg_fewshot.sh +++ b/convlab/base_models/t5/nlg/run_nlg_fewshot.sh @@ -83,6 +83,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name} diff --git a/convlab/base_models/t5/nlg/run_nlg_multitask.sh b/convlab/base_models/t5/nlg/run_nlg_multitask.sh index dec894aab37a37ba7923d60431fb22ef5ac4d6b6..9556bdcc0d55e4e520c371386438f9a99c71280c 100644 --- a/convlab/base_models/t5/nlg/run_nlg_multitask.sh +++ b/convlab/base_models/t5/nlg/run_nlg_multitask.sh @@ -89,6 +89,10 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -# python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name} +for name in ${names}; +do + echo "evaluating ${name}" + python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/${name}_predictions.json --dataset_name ${name} +done \ No newline at end of file diff --git a/convlab/base_models/t5/nlu/merge_predict_res.py b/convlab/base_models/t5/nlu/merge_predict_res.py index e247160769f7e5b0c9445b38e4dc2a5caa567fd0..8c522063f98e38773841554ddddfc9f476b18049 100755 --- a/convlab/base_models/t5/nlu/merge_predict_res.py +++ b/convlab/base_models/t5/nlu/merge_predict_res.py @@ -4,7 +4,7 @@ from convlab.util import load_dataset, load_nlu_data from convlab.base_models.t5.nlu.serialization import deserialize_dialogue_acts -def merge(dataset_names, speaker, save_dir, context_window_size, predict_result): +def merge(dataset_names, speaker, save_dir, context_window_size, predict_result, dial_ids_order): assert os.path.exists(predict_result) if save_dir is None: @@ -17,14 +17,18 @@ def merge(dataset_names, speaker, save_dir, context_window_size, predict_result) i = 0 for dataset_name in dataset_names.split('+'): print(dataset_name) - dataset = load_dataset(dataset_name, args.dial_ids_order) + single = [] + dataset = load_dataset(dataset_name, dial_ids_order) data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test'] for sample in data: sample['predictions'] = {'dialogue_acts': predict_result[i]} i += 1 + single.append(sample) merged.append(sample) + json.dump(single, open(os.path.join(save_dir, f'{dataset_name}_predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) @@ -35,8 +39,8 @@ if __name__ == '__main__': parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances') parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result') parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered') - parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json') + parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file test_generated_predictions.json') parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments') args = parser.parse_args() print(args) - merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result) + merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result, args.dial_ids_order) diff --git a/convlab/base_models/t5/nlu/run_nlu.sh b/convlab/base_models/t5/nlu/run_nlu.sh index b81b04c0f360fe55c25e55f85ff8ceac3578a99d..cf668b5d32179747819ad3dfba04d1fda954acec 100644 --- a/convlab/base_models/t5/nlu/run_nlu.sh +++ b/convlab/base_models/t5/nlu/run_nlu.sh @@ -1,8 +1,8 @@ n_gpus=1 task_name="nlu" -dataset_name=$1 -speaker="user" -context_window_size=$2 +dataset_name=crosswoz +speaker="all" +context_window_size=0 data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" cache_dir="../cache" @@ -17,10 +17,10 @@ target_column="dialogue_acts_seq" truncation_side="left" max_source_length=512 max_target_length=512 -model_name_or_path="t5-small" -per_device_train_batch_size=128 -per_device_eval_batch_size=64 -gradient_accumulation_steps=2 +model_name_or_path="/data/zhuqi/pre-trained-models/mt5-small" +per_device_train_batch_size=16 +per_device_eval_batch_size=16 +gradient_accumulation_steps=16 lr=1e-3 num_train_epochs=10 @@ -80,6 +80,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh index a966310a5bea242db413dda7b9ca12bcbda0ae43..5f04579f64c097dff1d4da50fa00231354c55d7d 100644 --- a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh +++ b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh @@ -83,6 +83,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_nlu_multitask.sh b/convlab/base_models/t5/nlu/run_nlu_multitask.sh index b91f21e3f02270ff2f1dfa42fe8baa8f16a20acc..4e29168007763e44817b8c8215705864d0131a8f 100644 --- a/convlab/base_models/t5/nlu/run_nlu_multitask.sh +++ b/convlab/base_models/t5/nlu/run_nlu_multitask.sh @@ -89,6 +89,10 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json +for name in ${names}; +do + echo "evaluating ${name}" + python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/${name}_predictions.json +done \ No newline at end of file diff --git a/convlab/base_models/t5/nlu/run_retnlu.sh b/convlab/base_models/t5/nlu/run_retnlu.sh index fd44e063dc84da86e4f77ead69b0e329ac0cc7d1..ede928abd601dc37ae48536dc2ad6229c7a4b556 100644 --- a/convlab/base_models/t5/nlu/run_retnlu.sh +++ b/convlab/base_models/t5/nlu/run_retnlu.sh @@ -81,6 +81,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh index e778c80bdc844dfea732421e9234e8965e20d987..a3bfbfaac21a5b195dcd89572aad386c7b82e612 100644 --- a/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh +++ b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh @@ -84,6 +84,6 @@ num_train_epochs=100 # --optim adafactor \ # --gradient_checkpointing -# python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +# python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh index 775b4b06ed35f82610466ca96e518e95eb9b86f8..5c4a091d533173691d0686c1a508ded7bee68737 100644 --- a/convlab/base_models/t5/nlu/run_retnlu_in_context.sh +++ b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh @@ -81,6 +81,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh index 913ef7cbad5fae0b3092c29fe0cd5f44604c333d..3a6c4ce0484c1b5366b9d081483fd3c4c9e5c247 100644 --- a/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh +++ b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh @@ -84,6 +84,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/run_seq2seq.py b/convlab/base_models/t5/run_seq2seq.py index 5fa921f0d4c855dc17b7f3b5d1daa8cc404f957c..bdf1b10f21c5f9442efc0e5381e6899966fc8872 100644 --- a/convlab/base_models/t5/run_seq2seq.py +++ b/convlab/base_models/t5/run_seq2seq.py @@ -37,6 +37,8 @@ from transformers import ( AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, + T5ForConditionalGeneration, + T5Tokenizer, DataCollatorForSeq2Seq, HfArgumentParser, EarlyStoppingCallback, @@ -358,22 +360,40 @@ def main(): revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - use_fast=model_args.use_fast_tokenizer, - truncation_side=model_args.truncation_side, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - model = AutoModelForSeq2SeqLM.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) + try: + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + truncation_side=model_args.truncation_side, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + except: + tokenizer = T5Tokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + truncation_side=model_args.truncation_side, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = T5ForConditionalGeneration.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) model.resize_token_embeddings(len(tokenizer)) @@ -612,16 +632,17 @@ def main(): # Predict if training_args.do_predict: - logger.info("*** Predict ***") - predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict") + file_prefix = os.path.splitext(os.path.basename(data_args.test_file))[0] + logger.info(f"*** Predict {file_prefix}***") + predict_results = trainer.predict(predict_dataset, metric_key_prefix=file_prefix) metrics = predict_results.metrics max_predict_samples = ( data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) ) - metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) + metrics[f"{file_prefix}_samples"] = min(max_predict_samples, len(predict_dataset)) - trainer.log_metrics("predict", metrics) - trainer.save_metrics("predict", metrics) + trainer.log_metrics(file_prefix, metrics) + trainer.save_metrics(file_prefix, metrics) if trainer.is_world_process_zero(): if training_args.predict_with_generate: @@ -629,10 +650,13 @@ def main(): predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True ) predictions = [pred.strip() for pred in predictions] - output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.json") + output_prediction_file = os.path.join(training_args.output_dir, f"{file_prefix}_generated_predictions.json") with open(output_prediction_file, "w", encoding='utf-8') as writer: - for sample, pred in zip(raw_datasets["test"], predictions): - sample["predictions"] = pred + for idx, sample in enumerate(raw_datasets["test"]): + if training_args.num_return_sequences > 1: + sample["predictions"] = predictions[idx*training_args.num_return_sequences:(idx+1)*training_args.num_return_sequences] + else: + sample["predictions"] = predictions[idx] writer.write(json.dumps(sample, ensure_ascii=False)+'\n') kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": data_args.task_name} diff --git a/setup.py b/setup.py index c2cae13c857561899c9b71f68e236f30d705f5ce..0a94fb0e50cbc36f96b4b3102690656bfa10f2e1 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import setup, find_packages setup( name='convlab', - version='3.0.0', + version='3.0.1', packages=find_packages(), license='Apache', description='An Open-source Dialog System Toolkit',