Commit 74d2408e authored by zqwerty

run nlg fewshot

parent f7190676
@@ -14,7 +14,7 @@ metric_name_or_path="nlg_metric.py"
 metric_for_best_model="bleu"
 source_column="context+da"
 target_column="response"
-truncation_side="right"
+truncation_side="left"
 max_source_length=512
 max_target_length=512
 model_name_or_path="t5-small"
@@ -40,6 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
+    --save_total_limit 3 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
@@ -72,8 +73,15 @@ python ../run_seq2seq.py \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing
 
 python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
 
-python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
+python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name}
@@ -76,8 +76,15 @@ python ../run_seq2seq.py \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing
 
 python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
 
-python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
+python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name}
n_gpus=1
task_name="nlg"
dataset_name="sgd+tm1+tm2+tm3"
speaker="system"
context_window_size=0
data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="nlg_metric.py"
metric_for_best_model="bleu"
source_column="context+da"
target_column="response"
truncation_side="left"
max_source_length=512
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=8
lr=1e-3
num_train_epochs=1
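# Note: on a single GPU, each optimizer step here covers
# per_device_train_batch_size * gradient_accumulation_steps = 64 * 8 = 512 samples.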
names=$(echo ${dataset_name} | tr "+" "\n")
mkdir -p ${data_dir}
for name in ${names};
do
echo "preprocessing ${name}"
python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
if [ "${name}" != "${dataset_name}" ]; then
cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/train.json" >> ${train_file}
cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/validation.json" >> ${validation_file}
cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/test.json" >> ${test_file}
fi
done
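# The loop above preprocesses each individual dataset and appends its splits to the
# combined files in ${data_dir}. This assumes each split file stores one JSON sample
# per line, so plain concatenation yields valid combined splits for the run below.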
python ../run_seq2seq.py \
    --task_name ${task_name} \
    --train_file ${train_file} \
    --validation_file ${validation_file} \
    --source_column ${source_column} \
    --target_column ${target_column} \
    --max_source_length ${max_source_length} \
    --max_target_length ${max_target_length} \
    --truncation_side ${truncation_side} \
    --model_name_or_path ${model_name_or_path} \
    --do_train \
    --do_eval \
    --save_strategy epoch \
    --evaluation_strategy epoch \
    --prediction_loss_only \
    --cache_dir ${cache_dir} \
    --output_dir ${output_dir} \
    --logging_dir ${logging_dir} \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
    --debug underflow_overflow \
    --adafactor \
    --gradient_checkpointing
@@ -34,7 +34,7 @@ def evaluate(predict_result, ontology):
     candidates = []
     for i in range(len(predict_result)):
         references.append(predict_result[i]['utterance'])
-        candidates.append(predict_result[i]['prediction'])
+        candidates.append(predict_result[i]['predictions']['utterance'])
     # metrics['bleu'] = corpus_bleu(references, candidates)
     metrics['bleu'] = sacrebleu.corpus_bleu(candidates, [references], lowercase=True).score
@@ -54,7 +54,7 @@ def evaluate(predict_result, ontology):
     score_list = []
     for item in predict_result:
         da = item['dialogue_acts']
-        utterance = item['prediction']
+        utterance = item['predictions']['utterance']
         missing_count = 0
         redundant_count = 0
         all_count = 0
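Both hunks switch the evaluator from a flat 'prediction' field to a nested 'predictions' dict, presumably matching what merge_predict_res.py writes into predictions.json. A minimal sketch of the per-sample structure this indexing assumes (keys other than those read in the diff are illustrative assumptions):

# One entry of ${output_dir}/predictions.json as the updated evaluate() expects it.
sample = {
    # gold reference response and the dialogue acts it should realise
    "utterance": "There are 3 restaurants in the centre.",
    "dialogue_acts": {"categorical": [], "non-categorical": [], "binary": []},
    # the generated response now sits under a nested dict; evaluate() reads
    # sample["predictions"]["utterance"] instead of the old sample["prediction"]
    "predictions": {"utterance": "I found 3 restaurants in the centre."},
}

evaluate() scores the predicted utterance against the gold one with sacreBLEU (first hunk) and against the dialogue acts via the missing/redundant value counts (second hunk).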