diff --git a/convlab2/base_models/t5/dst/run_dst_fewshot.sh b/convlab2/base_models/t5/dst/run_dst_fewshot.sh
new file mode 100644
index 0000000000000000000000000000000000000000..298a37f17a1c0817ff257742b5aa6e61bb9cd5d0
--- /dev/null
+++ b/convlab2/base_models/t5/dst/run_dst_fewshot.sh
@@ -0,0 +1,92 @@
+n_gpus=1
+task_name="dst"
+dataset_name=$1
+# few-shot ratio and dialogue-id order; assumed to be the remaining positional args,
+# matching the -r/-o options passed to create_data.py below
+ratio=$2
+dial_ids_order=$3
+speaker="user"
+context_window_size=100
+data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}"
+output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="dst_metric.py"
+metric_for_best_model="accuracy"
+source_column="context"
+target_column="state_seq"
+truncation_side="left"
+max_source_length=1024
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=64
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=100
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -r ${ratio} -o ${dial_ids_order}
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 3 \
+    --early_stopping_patience 10 \
+    --prediction_loss_only \
+    --load_best_model_at_end \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
+
+python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
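Usage sketch for the script above (per the positional reads: $1 is the dataset name, $2 the few-shot ratio, $3 the dialogue-id order; the concrete values here are illustrative):

    # fine-tune T5 DST on 1% of the dialogues, few-shot split order 0
    bash run_dst_fewshot.sh multiwoz21 0.01 0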
diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py
index 0e2b5720c80971faf0d044ee0fa30dba1a9eb26b..c76bb5cd690001e550aef4a1ce287d007c5a066d 100644
--- a/convlab2/base_models/t5/run_seq2seq.py
+++ b/convlab2/base_models/t5/run_seq2seq.py
@@ -219,7 +219,7 @@ class DataTrainingArguments:
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
     early_stopping_patience: Optional[int] = field(
-        default=0, metadata={"help": "early stopping patience, default is 0 which means not using early stopping."},
+        default=10, metadata={"help": "Early stopping patience; set to 0 to disable early stopping."},
     )
 
     def __post_init__(self):
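Note that this flips the default: training runs now early-stop after 10 evaluations without improvement unless the caller opts out. A minimal opt-out invocation sketch (all other required arguments elided):

    # disable early stopping and train for the full num_train_epochs
    python run_seq2seq.py --early_stopping_patience 0 ...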