Commit a26a42b2 authored by zqwerty

update run_seq2seq: add early stopping. update dst script to support fewshot learning

parent 74d2408e
@@ -6,7 +6,7 @@ from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state
 def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name)
+    dataset = load_dataset(dataset_name, args.dial_ids_order)
     data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
     if save_dir is None:
@@ -29,6 +29,7 @@ if __name__ == '__main__':
     parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result')
     parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
     parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
+    parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
     args = parser.parse_args()
     print(args)
     merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
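With the new --dial_ids_order flag, a few-shot merge run would presumably be invoked as sketched below; the dataset name, context size, order index, and output path are placeholders, not values taken from this commit:

# Hypothetical invocation: merge predictions for the few-shot split selected by -o/--dial_ids_order.
python merge_predict_res.py -d multiwoz21 -s user -c 100 -o 0 \
    -p output/dst/multiwoz21/user/context_100/generated_predictions.json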
@@ -24,14 +24,12 @@ gradient_accumulation_steps=2
 lr=1e-3
 num_train_epochs=10
-python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -l t5-small
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size}
-python -m torch.distributed.launch \
-    --nproc_per_node ${n_gpus} ../run_seq2seq.py \
+python ../run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
     --validation_file ${validation_file} \
     --test_file ${test_file} \
     --source_column ${source_column} \
     --target_column ${target_column} \
     --max_source_length ${max_source_length} \
@@ -40,9 +38,9 @@ python -m torch.distributed.launch \
     --model_name_or_path ${model_name_or_path} \
     --do_train \
     --do_eval \
-    --do_predict \
     --save_strategy epoch \
     --evaluation_strategy epoch \
+    --save_total_limit 3 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
@@ -58,8 +56,7 @@ python -m torch.distributed.launch \
     --adafactor \
     --gradient_checkpointing

-python -m torch.distributed.launch \
-    --nproc_per_node ${n_gpus} ../run_seq2seq.py \
+python ../run_seq2seq.py \
     --task_name ${task_name} \
     --test_file ${test_file} \
     --source_column ${source_column} \
@@ -76,7 +73,14 @@ python -m torch.distributed.launch \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing

 python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
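The commit replaces both torch.distributed.launch invocations with plain single-process python runs. If multi-GPU training is still wanted, the same script could presumably be launched through torchrun instead; this is a sketch, not part of the commit:

# Sketch only (not part of this commit): multi-GPU launch of the same script via torchrun,
# which supersedes the removed torch.distributed.launch entry point.
# Append the same --task_name/--train_file/... arguments used in the single-process call above.
torchrun --nproc_per_node ${n_gpus} ../run_seq2seq.py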
n_gpus=1
task_name="dst"
dataset_name="sgd+tm1+tm2+tm3"
speaker="user"
context_window_size=100
data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="dst_metric.py"
metric_for_best_model="accuracy"
source_column="context"
target_column="state_seq"
truncation_side="left"
max_source_length=1024
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=1
names=$(echo ${dataset_name} | tr "+" "\n")
mkdir -p ${data_dir}
for name in ${names};
do
    echo "preprocessing ${name}"
    python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
    if [ "${name}" != "${dataset_name}" ]; then
        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/train.json" >> ${train_file}
        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/validation.json" >> ${validation_file}
        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/test.json" >> ${test_file}
    fi
done
python ../run_seq2seq.py \
    --task_name ${task_name} \
    --train_file ${train_file} \
    --validation_file ${validation_file} \
    --source_column ${source_column} \
    --target_column ${target_column} \
    --max_source_length ${max_source_length} \
    --max_target_length ${max_target_length} \
    --truncation_side ${truncation_side} \
    --model_name_or_path ${model_name_or_path} \
    --do_train \
    --do_eval \
    --save_strategy epoch \
    --evaluation_strategy epoch \
    --prediction_loss_only \
    --cache_dir ${cache_dir} \
    --output_dir ${output_dir} \
    --logging_dir ${logging_dir} \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
    --debug underflow_overflow \
    --adafactor \
    --gradient_checkpointing
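The for-loop earlier in this script builds the combined sgd+tm1+tm2+tm3 split by appending the per-dataset files with cat, which only works if create_data.py writes one JSON object per line (JSON Lines). Under that assumption, the merged data can be sanity-checked as follows (a sketch, not part of the commit):

# Assumes JSON Lines output with the "context" / "state_seq" columns used above.
head -n 1 ${data_dir}/train.json | python -m json.tool   # inspect one merged example
wc -l ${data_dir}/train.json ${data_dir}/validation.json ${data_dir}/test.json   # per-split example counts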
@@ -41,6 +41,7 @@ from transformers import (
     HfArgumentParser,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
+    EarlyStoppingCallback,
     set_seed,
 )
 from transformers.trainer_utils import EvalPrediction, get_last_checkpoint
@@ -217,6 +218,9 @@ class DataTrainingArguments:
     source_prefix: Optional[str] = field(
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
+    early_stopping_patience: Optional[int] = field(
+        default=0, metadata={"help": "Early stopping patience. Default is 0, which disables early stopping."},
+    )

     def __post_init__(self):
         if (
@@ -561,6 +565,8 @@ def main():
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
     )
+    if data_args.early_stopping_patience > 0:
+        trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience))

     # Training
     if training_args.do_train:
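For the new option to actually stop training, HF Trainer's EarlyStoppingCallback also needs the run to track a best model: load_best_model_at_end must be enabled, metric_for_best_model must be set, and the evaluation and save strategies must match. A sketch of the corresponding flags (assumed usage, not a command from this commit):

# Sketch: flags a training run would presumably combine with the new option;
# the data/model arguments are the ones shown in the scripts above.
python ../run_seq2seq.py \
    --early_stopping_patience 5 \
    --load_best_model_at_end \
    --metric_for_best_model ${metric_for_best_model} \
    --evaluation_strategy epoch \
    --save_strategy epoch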