From 89e557bc0744218ff0da1efc90415958dcc7134e Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Wed, 27 Apr 2022 15:56:29 +0800
Subject: [PATCH] add nlg fewshot, update run_seq2seq for separated truncation

---
 convlab2/base_models/t5/create_data.py        |  4 +-
 .../base_models/t5/nlg/merge_predict_res.py   |  7 +-
 .../base_models/t5/nlg/run_nlg_fewshot.sh     | 84 +++++++++++++++++++
 convlab2/base_models/t5/run_seq2seq.py        | 11 ++-
 4 files changed, 99 insertions(+), 7 deletions(-)
 create mode 100644 convlab2/base_models/t5/nlg/run_nlg_fewshot.sh

diff --git a/convlab2/base_models/t5/create_data.py b/convlab2/base_models/t5/create_data.py
index b2091f52..305538ed 100644
--- a/convlab2/base_models/t5/create_data.py
+++ b/convlab2/base_models/t5/create_data.py
@@ -87,9 +87,9 @@ def create_nlg_data(dataset, data_dir, args):
             dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
             if args.context_window_size>0:
                 context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[f'{sample["speaker"]}: '])
-                context = f'{dialogue_acts_seq}\n{context}'
+                context = f'{dialogue_acts_seq}\n\n{context}'
             else:
-                context = f'{dialogue_acts_seq}\n{sample["speaker"]}: '
+                context = f'{dialogue_acts_seq}\n\n{sample["speaker"]}: '
             assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq))
             data.append(json.dumps({'context+da': context, 'response': sample['utterance']}, ensure_ascii=False)+'\n')
 
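Note: the dialogue-act sequence and the dialogue context are now joined by a
blank line ('\n\n'), which gives run_seq2seq.py a fixed split point for
truncating each segment on its own. A minimal sketch of the resulting model
input, using a made-up dialogue-act serialization and utterances purely for
illustration:

    # hypothetical serialized NLG sample after this change
    dialogue_acts_seq = "[inform][restaurant]([name][Pizza Hut])"  # made-up serialization
    context = "user: I want some pizza.\nsystem: "
    print(f"{dialogue_acts_seq}\n\n{context}")
    # [inform][restaurant]([name][Pizza Hut])
    #
    # user: I want some pizza.
    # system:
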
diff --git a/convlab2/base_models/t5/nlg/merge_predict_res.py b/convlab2/base_models/t5/nlg/merge_predict_res.py
index 91e6055e..205226fa 100755
--- a/convlab2/base_models/t5/nlg/merge_predict_res.py
+++ b/convlab2/base_models/t5/nlg/merge_predict_res.py
@@ -5,7 +5,7 @@ from convlab2.util import load_dataset, load_nlg_data
 
-def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+def merge(dataset_name, speaker, save_dir, context_window_size, predict_result, dial_ids_order):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name)
+    dataset = load_dataset(dataset_name, dial_ids_order)
     data = load_nlg_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
     
     if save_dir is None:
@@ -28,6 +28,7 @@ if __name__ == '__main__':
     parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result')
     parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
     parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
+    parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
     args = parser.parse_args()
     print(args)
-    merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
+    merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result, args.dial_ids_order)
diff --git a/convlab2/base_models/t5/nlg/run_nlg_fewshot.sh b/convlab2/base_models/t5/nlg/run_nlg_fewshot.sh
new file mode 100644
index 00000000..97bdc9ec
--- /dev/null
+++ b/convlab2/base_models/t5/nlg/run_nlg_fewshot.sh
@@ -0,0 +1,84 @@
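+# usage: bash run_nlg_fewshot.sh <dataset_name> <context_window_size> <ratio> <dial_ids_order>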
+n_gpus=1
+task_name="nlg"
+dataset_name=$1
+speaker="system"
+context_window_size=$2
+ratio=$3
+dial_ids_order=$4
+data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}"
+output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlg_metric.py"
+metric_for_best_model="bleu"
+source_column="context+da"
+target_column="response"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=4
+lr=1e-3
+num_train_epochs=100
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -r ${ratio} -o ${dial_ids_order}
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 3 \
+    --prediction_loss_only \
+    --load_best_model_at_end \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_eval_batch_size ${per_device_eval_batch_size}
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
+
+python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
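
Note: the first run_seq2seq.py call fine-tunes t5-small on the few-shot split
and the second reloads the fine-tuned checkpoint from ${output_dir} to generate
test-set predictions; merge_predict_res.py then attaches the generated
responses to the original samples before evaluation. An invocation such as
"bash run_nlg_fewshot.sh multiwoz21 0 0.1 0" (the dataset, context window size,
ratio, and dial-ids order shown here are only illustrative values) runs the
whole pipeline end to end.
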
diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py
index 2f0f5481..1c57a3e1 100644
--- a/convlab2/base_models/t5/run_seq2seq.py
+++ b/convlab2/base_models/t5/run_seq2seq.py
@@ -25,6 +25,7 @@ import sys
 import json
 from dataclasses import dataclass, field
 from typing import Optional
+from functools import reduce
 
 import datasets
 import numpy as np
@@ -439,8 +440,14 @@ def main():
                 inputs.append(examples[source_column][i])
                 targets.append(examples[target_column][i])
 
-        inputs = [prefix + inp for inp in inputs]
-        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
+        inputs = [prefix + '\n\n' + inp for inp in inputs]
+        if padding:
+            model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
+        else:
+            # truncate each '\n\n'-separated part independently, then splice the parts back together (x[:-1] drops each part's trailing EOS)
+            split_inputs = [inp.split('\n\n') for inp in inputs]
+            split_model_inputs = [tokenizer(x, max_length=data_args.max_source_length, padding=False, truncation=True) for x in split_inputs]
+            model_inputs = {k: [reduce(lambda x, y: x[:-1]+y, item[k]) for item in split_model_inputs] for k in split_model_inputs[0]}
 
         # Setup the tokenizer for targets
         with tokenizer.as_target_tokenizer():
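
Note: for unpadded batches, each '\n\n'-separated segment (prefix, dialogue
acts, context) is tokenized and truncated on its own before the per-segment
token ids are spliced back together; the x[:-1] in the reduce discards each
intermediate segment's trailing EOS so that only the final one survives. A
minimal standalone sketch of the same trick, assuming a T5-style tokenizer
that appends an EOS token to every encoded string (the dialogue-act string
below is made up):

    from functools import reduce
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    inp = "[inform][restaurant]([name][Pizza Hut])\n\nuser: I want some pizza.\nsystem: "

    # tokenize the two segments separately so each is truncated on its own
    parts = inp.split('\n\n')
    encoded = tokenizer(parts, max_length=512, truncation=True)
    # splice the per-segment ids; x[:-1] drops each intermediate EOS token
    merged = {k: reduce(lambda x, y: x[:-1] + y, encoded[k]) for k in encoded}
    print(merged['input_ids'][-1] == tokenizer.eos_token_id)  # True
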
-- 
GitLab