Commit 3e77b6dc authored by zqwerty

add dst fewshot shell

parent a26a42b2
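#!/bin/bash
# Few-shot DST: fine-tune t5-small on a subset of a unified-format dataset.
# Usage (hypothetical example; the script filename is an assumption):
#   bash run_dst_fewshot.sh multiwoz21 0.1 0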
n_gpus=1
task_name="dst"
dataset_name=$1
ratio=$2            # fraction of training dialogues kept for few-shot training
dial_ids_order=$3   # index of the dialogue-id shuffling order (few-shot sampling seed)
speaker="user"
context_window_size=100
data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}"
output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="dst_metric.py"
metric_for_best_model="accuracy"  # not passed below; with --prediction_loss_only, the best checkpoint is selected by eval loss
source_column="context"
target_column="state_seq"
truncation_side="left"
max_source_length=1024
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=100
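# 1) Build the few-shot data split (written to ${data_dir}).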
python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -r ${ratio} -o ${dial_ids_order}
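# 2) Fine-tune t5-small; the best checkpoint is restored at the end of training.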
python ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 3 \
--early_stopping_patience 10 \
--prediction_loss_only \
--load_best_model_at_end \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--debug underflow_overflow \
--adafactor \
--gradient_checkpointing
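# 3) Generate dialogue-state predictions on the test set with the fine-tuned model.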
python ../run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--debug underflow_overflow \
--adafactor \
--gradient_checkpointing
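# 4) Merge the generated predictions back into the unified dataset format and score them.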
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
run_seq2seq.py
@@ -219,7 +219,7 @@ class DataTrainingArguments:
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
     early_stopping_patience: Optional[int] = field(
-        default=0, metadata={"help": "early stopping patience, default is 0 which means not using early stopping."},
+        default=10, metadata={"help": "early stopping patience, set to 0 if you do not want to use early stopping."},
     )

     def __post_init__(self):
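For context on the hunk above: a non-zero early_stopping_patience is typically wired into Hugging Face's EarlyStoppingCallback, which requires --load_best_model_at_end and periodic evaluation (both set by the shell script). A minimal sketch, assuming the surrounding trainer setup that this diff does not show:

from transformers import EarlyStoppingCallback

# Sketch only, not the actual run_seq2seq.py code: with the new default of 10,
# training stops after 10 consecutive evaluations with no improvement in the
# best-model metric (eval loss here, given --prediction_loss_only); a patience
# of 0 returns no callback and thus disables early stopping, as the help says.
def build_callbacks(early_stopping_patience: int) -> list:
    if early_stopping_patience > 0:
        return [EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)]
    return []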