From a26a42b2b361f5234fd43bc9a6fe0cca9951a74b Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Fri, 29 Apr 2022 17:21:07 +0800
Subject: [PATCH] update run_seq2seq: add early stopping. update dst script to
 support fewshot learning

---
 .../base_models/t5/dst/merge_predict_res.py   |  3 +-
 convlab2/base_models/t5/dst/run_dst.sh        | 20 +++---
 .../base_models/t5/dst/run_dst_pretrain.sh    | 67 +++++++++++++++++++
 convlab2/base_models/t5/run_seq2seq.py        |  6 ++
 4 files changed, 87 insertions(+), 9 deletions(-)
 create mode 100644 convlab2/base_models/t5/dst/run_dst_pretrain.sh

diff --git a/convlab2/base_models/t5/dst/merge_predict_res.py b/convlab2/base_models/t5/dst/merge_predict_res.py
index 0a80ee80..9b942260 100755
--- a/convlab2/base_models/t5/dst/merge_predict_res.py
+++ b/convlab2/base_models/t5/dst/merge_predict_res.py
@@ -6,7 +6,7 @@ from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state
 
 def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name)
+    dataset = load_dataset(dataset_name, args.dial_ids_order)
     data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
     
     if save_dir is None:
@@ -29,6 +29,7 @@ if __name__ == '__main__':
     parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result')
     parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
     parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
+    parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
     args = parser.parse_args()
     print(args)
     merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
diff --git a/convlab2/base_models/t5/dst/run_dst.sh b/convlab2/base_models/t5/dst/run_dst.sh
index 7ee6041e..c678005e 100644
--- a/convlab2/base_models/t5/dst/run_dst.sh
+++ b/convlab2/base_models/t5/dst/run_dst.sh
@@ -24,14 +24,12 @@ gradient_accumulation_steps=2
 lr=1e-3
 num_train_epochs=10
 
-python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -l t5-small
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size}
 
-python -m torch.distributed.launch \
-    --nproc_per_node ${n_gpus} ../run_seq2seq.py \
+python ../run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
     --validation_file ${validation_file} \
-    --test_file ${test_file} \
     --source_column ${source_column} \
     --target_column ${target_column} \
     --max_source_length ${max_source_length} \
@@ -40,9 +38,9 @@ python -m torch.distributed.launch \
     --model_name_or_path ${model_name_or_path} \
     --do_train \
     --do_eval \
-    --do_predict \
     --save_strategy epoch \
     --evaluation_strategy epoch \
+    --save_total_limit 3 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
@@ -58,8 +56,7 @@ python -m torch.distributed.launch \
     --adafactor \
     --gradient_checkpointing
 
-python -m torch.distributed.launch \
-    --nproc_per_node ${n_gpus} ../run_seq2seq.py \
+python ../run_seq2seq.py \
     --task_name ${task_name} \
     --test_file ${test_file} \
     --source_column ${source_column} \
@@ -76,7 +73,14 @@ python -m torch.distributed.launch \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing
 
 python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
 
diff --git a/convlab2/base_models/t5/dst/run_dst_pretrain.sh b/convlab2/base_models/t5/dst/run_dst_pretrain.sh
new file mode 100644
index 00000000..f1c5c3d4
--- /dev/null
+++ b/convlab2/base_models/t5/dst/run_dst_pretrain.sh
@@ -0,0 +1,67 @@
+n_gpus=1
+task_name="dst"
+dataset_name="sgd+tm1+tm2+tm3"
+speaker="user"
+context_window_size=100
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="dst_metric.py"
+metric_for_best_model="accuracy"
+source_column="context"
+target_column="state_seq"
+truncation_side="left"
+max_source_length=1024
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=64
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=1
+
+names=$(echo ${dataset_name} | tr "+" "\n")
+mkdir -p ${data_dir}
+for name in ${names};
+do
+    echo "preprocessing ${name}"
+    python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
+    if [ "${name}" != "${dataset_name}" ]; then
+        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/train.json" >> ${train_file}
+        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/validation.json" >> ${validation_file}
+        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/test.json" >> ${test_file}
+    fi
+done
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing
diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py
index 1c57a3e1..0e2b5720 100644
--- a/convlab2/base_models/t5/run_seq2seq.py
+++ b/convlab2/base_models/t5/run_seq2seq.py
@@ -41,6 +41,7 @@ from transformers import (
     HfArgumentParser,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
+    EarlyStoppingCallback,
     set_seed,
 )
 from transformers.trainer_utils import EvalPrediction, get_last_checkpoint
@@ -217,6 +218,9 @@ class DataTrainingArguments:
     source_prefix: Optional[str] = field(
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
+    early_stopping_patience: Optional[int] = field(
+        default=0, metadata={"help": "early stopping patience, default is 0 which means not using early stopping."},
+    )
 
     def __post_init__(self):
         if (
@@ -561,6 +565,8 @@ def main():
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
     )
+    if data_args.early_stopping_patience > 0:
+        trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience))
 
     # Training
     if training_args.do_train:
-- 
GitLab