From a26a42b2b361f5234fd43bc9a6fe0cca9951a74b Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Fri, 29 Apr 2022 17:21:07 +0800 Subject: [PATCH] update run_seq2seq: add early stopping. update dst script to support fewshot learning --- .../base_models/t5/dst/merge_predict_res.py | 3 +- convlab2/base_models/t5/dst/run_dst.sh | 20 +++--- .../base_models/t5/dst/run_dst_pretrain.sh | 67 +++++++++++++++++++ convlab2/base_models/t5/run_seq2seq.py | 6 ++ 4 files changed, 87 insertions(+), 9 deletions(-) create mode 100644 convlab2/base_models/t5/dst/run_dst_pretrain.sh diff --git a/convlab2/base_models/t5/dst/merge_predict_res.py b/convlab2/base_models/t5/dst/merge_predict_res.py index 0a80ee80..9b942260 100755 --- a/convlab2/base_models/t5/dst/merge_predict_res.py +++ b/convlab2/base_models/t5/dst/merge_predict_res.py @@ -6,7 +6,7 @@ from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state def merge(dataset_name, speaker, save_dir, context_window_size, predict_result): assert os.path.exists(predict_result) - dataset = load_dataset(dataset_name) + dataset = load_dataset(dataset_name, args.dial_ids_order) data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test'] if save_dir is None: @@ -29,6 +29,7 @@ if __name__ == '__main__': parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result') parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered') parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json') + parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments') args = parser.parse_args() print(args) merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result) diff --git a/convlab2/base_models/t5/dst/run_dst.sh b/convlab2/base_models/t5/dst/run_dst.sh index 7ee6041e..c678005e 100644 --- a/convlab2/base_models/t5/dst/run_dst.sh +++ b/convlab2/base_models/t5/dst/run_dst.sh @@ -24,14 +24,12 @@ gradient_accumulation_steps=2 lr=1e-3 num_train_epochs=10 -python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -l t5-small +python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -python -m torch.distributed.launch \ - --nproc_per_node ${n_gpus} ../run_seq2seq.py \ +python ../run_seq2seq.py \ --task_name ${task_name} \ --train_file ${train_file} \ --validation_file ${validation_file} \ - --test_file ${test_file} \ --source_column ${source_column} \ --target_column ${target_column} \ --max_source_length ${max_source_length} \ @@ -40,9 +38,9 @@ python -m torch.distributed.launch \ --model_name_or_path ${model_name_or_path} \ --do_train \ --do_eval \ - --do_predict \ --save_strategy epoch \ --evaluation_strategy epoch \ + --save_total_limit 3 \ --prediction_loss_only \ --cache_dir ${cache_dir} \ --output_dir ${output_dir} \ @@ -58,8 +56,7 @@ python -m torch.distributed.launch \ --adafactor \ --gradient_checkpointing -python -m torch.distributed.launch \ - --nproc_per_node ${n_gpus} ../run_seq2seq.py \ +python ../run_seq2seq.py \ --task_name ${task_name} \ --test_file ${test_file} \ --source_column ${source_column} \ @@ -76,7 +73,14 @@ python -m torch.distributed.launch \ --logging_dir ${logging_dir} \ --overwrite_output_dir \ --preprocessing_num_workers 4 \ - --per_device_eval_batch_size ${per_device_eval_batch_size} + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --learning_rate ${lr} \ + --num_train_epochs ${num_train_epochs} \ + --debug underflow_overflow \ + --adafactor \ + --gradient_checkpointing python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json diff --git a/convlab2/base_models/t5/dst/run_dst_pretrain.sh b/convlab2/base_models/t5/dst/run_dst_pretrain.sh new file mode 100644 index 00000000..f1c5c3d4 --- /dev/null +++ b/convlab2/base_models/t5/dst/run_dst_pretrain.sh @@ -0,0 +1,67 @@ +n_gpus=1 +task_name="dst" +dataset_name="sgd+tm1+tm2+tm3" +speaker="user" +context_window_size=100 +data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" +output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" +cache_dir="../cache" +logging_dir="${output_dir}/runs" +train_file="${data_dir}/train.json" +validation_file="${data_dir}/validation.json" +test_file="${data_dir}/test.json" +metric_name_or_path="dst_metric.py" +metric_for_best_model="accuracy" +source_column="context" +target_column="state_seq" +truncation_side="left" +max_source_length=1024 +max_target_length=512 +model_name_or_path="t5-small" +per_device_train_batch_size=64 +per_device_eval_batch_size=64 +gradient_accumulation_steps=2 +lr=1e-3 +num_train_epochs=1 + +names=$(echo ${dataset_name} | tr "+" "\n") +mkdir -p ${data_dir} +for name in ${names}; +do + echo "preprocessing ${name}" + python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size} + if [ "${name}" != "${dataset_name}" ]; then + cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/train.json" >> ${train_file} + cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/validation.json" >> ${validation_file} + cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/test.json" >> ${test_file} + fi +done + +python ../run_seq2seq.py \ + --task_name ${task_name} \ + --train_file ${train_file} \ + --validation_file ${validation_file} \ + --source_column ${source_column} \ + --target_column ${target_column} \ + --max_source_length ${max_source_length} \ + --max_target_length ${max_target_length} \ + --truncation_side ${truncation_side} \ + --model_name_or_path ${model_name_or_path} \ + --do_train \ + --do_eval \ + --save_strategy epoch \ + --evaluation_strategy epoch \ + --prediction_loss_only \ + --cache_dir ${cache_dir} \ + --output_dir ${output_dir} \ + --logging_dir ${logging_dir} \ + --overwrite_output_dir \ + --preprocessing_num_workers 4 \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --learning_rate ${lr} \ + --num_train_epochs ${num_train_epochs} \ + --debug underflow_overflow \ + --adafactor \ + --gradient_checkpointing diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py index 1c57a3e1..0e2b5720 100644 --- a/convlab2/base_models/t5/run_seq2seq.py +++ b/convlab2/base_models/t5/run_seq2seq.py @@ -41,6 +41,7 @@ from transformers import ( HfArgumentParser, Seq2SeqTrainer, Seq2SeqTrainingArguments, + EarlyStoppingCallback, set_seed, ) from transformers.trainer_utils import EvalPrediction, get_last_checkpoint @@ -217,6 +218,9 @@ class DataTrainingArguments: source_prefix: Optional[str] = field( default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} ) + early_stopping_patience: Optional[int] = field( + default=0, metadata={"help": "early stopping patience, default is 0 which means not using early stopping."}, + ) def __post_init__(self): if ( @@ -561,6 +565,8 @@ def main(): data_collator=data_collator, compute_metrics=compute_metrics if training_args.predict_with_generate else None, ) + if data_args.early_stopping_patience > 0: + trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience)) # Training if training_args.do_train: -- GitLab