From 2cb15bb250b2651f66e706471f81a17b7ae236ee Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Mon, 4 Jul 2022 15:53:10 +0800
Subject: [PATCH] nlg multitask exp

---
 convlab/base_models/t5/create_data.py         | 10 +-
 convlab/base_models/t5/nlg/merge_data.py      | 25 +++
 .../base_models/t5/nlg/merge_predict_res.py   | 22 ++--
 convlab/base_models/t5/nlg/nlg_metric.py      |  2 +
 convlab/base_models/t5/nlg/run_nlg.sh         |  2 +-
 convlab/base_models/t5/nlg/run_nlg_fewshot.sh |  2 +-
 .../base_models/t5/nlg/run_nlg_multitask.sh   | 99 +++++++++++++++++++
 .../base_models/t5/nlg/run_nlg_pretrain.sh    |  8 +-
 convlab/nlg/evaluate_unified_datasets.py      |  2 +
 9 files changed, 156 insertions(+), 16 deletions(-)
 create mode 100644 convlab/base_models/t5/nlg/merge_data.py
 create mode 100644 convlab/base_models/t5/nlg/run_nlg_multitask.sh

diff --git a/convlab/base_models/t5/create_data.py b/convlab/base_models/t5/create_data.py
index 48685810..75c28284 100644
--- a/convlab/base_models/t5/create_data.py
+++ b/convlab/base_models/t5/create_data.py
@@ -85,6 +85,9 @@ def create_nlg_data(dataset, data_dir, args):
         data = []
         for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
             dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
+            if len(dialogue_acts_seq) == 0:
+                # skip empty dialogue acts
+                continue
             if args.context_window_size>0:
                 context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[f'{sample["speaker"]}: '])
                 context = f'{dialogue_acts_seq}\n\n{context}'
@@ -145,10 +148,11 @@ if __name__ == '__main__':
     if args.len_tokenizer:
         tokenizer = AutoTokenizer.from_pretrained(args.len_tokenizer)
     for dataset_name in tqdm(args.datasets, desc='datasets'):
-        dataset = load_dataset(dataset_name, args.dial_ids_order)
         if args.ratio:
-            dataset['train'] = dataset['train'][:round(len(dataset['train'])*args.ratio)]
-            dataset['validation'] = dataset['validation'][:round(len(dataset['validation'])*args.ratio)]
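+            # subsample train/validation splits to args.ratio via load_dataset's split2ratio argument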
+            dataset = load_dataset(dataset_name, dial_ids_order=args.dial_ids_order, split2ratio={'train': args.ratio, 'validation': args.ratio})
+        else:
+            dataset = load_dataset(dataset_name, args.dial_ids_order)
         for task_name in tqdm(args.tasks, desc='tasks', leave=False):
             data_dir = os.path.join('data', task_name, (dataset_name if not args.ratio else f'{dataset_name}_{args.ratio}_order{args.dial_ids_order}'))
             data_by_split = eval(f"create_{task_name}_data")(dataset, data_dir, args)
diff --git a/convlab/base_models/t5/nlg/merge_data.py b/convlab/base_models/t5/nlg/merge_data.py
new file mode 100644
index 00000000..6b3f843a
--- /dev/null
+++ b/convlab/base_models/t5/nlg/merge_data.py
@@ -0,0 +1,25 @@
+import json
+import os
+import sys
+
+if __name__ == '__main__':
+    merged_data = {'train': [], 'validation': [], 'test': []}
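+    # dataset names to merge are passed as positional command-line arguments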
+    print(f'merging datasets: {sys.argv[1:]}')
+    for dataset_name in sys.argv[1:]:
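+        # per-dataset NLG data created by create_data.py; the path assumes speaker 'system' and context window 0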
+        data_dir = os.path.join('data/nlg', dataset_name, 'system/context_0')
+        for data_split in merged_data:
+            with open(os.path.join(data_dir, f'{data_split}.json'), 'r') as f:
+                for line in f:
+                    item = json.loads(line)
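+                    # prefix each source sequence with its dataset name so samples remain distinguishable after merging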
+                    item['context+da'] = f"{dataset_name}: {item['context+da']}"
+                    merged_data[data_split].append(item)
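+    # write the merged splits to a directory named by joining the dataset names with '+'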
+    data_dir = os.path.join('data/nlg', '+'.join(sys.argv[1:]), 'system/context_0')
+    os.makedirs(data_dir, exist_ok=True)
+    for data_split in merged_data:
+        with open(os.path.join(data_dir, f'{data_split}.json'), 'w') as f:
+            for item in merged_data[data_split]:
+                f.write(json.dumps(item)+'\n')
diff --git a/convlab/base_models/t5/nlg/merge_predict_res.py b/convlab/base_models/t5/nlg/merge_predict_res.py
index 7de38fa1..d21fd489 100755
--- a/convlab/base_models/t5/nlg/merge_predict_res.py
+++ b/convlab/base_models/t5/nlg/merge_predict_res.py
@@ -3,10 +3,8 @@ import os
 from convlab.util import load_dataset, load_nlg_data
 
 
-def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+def merge(dataset_names, speaker, save_dir, context_window_size, predict_result):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name, args.dial_ids_order)
-    data = load_nlg_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
     
     if save_dir is None:
         save_dir = os.path.dirname(predict_result)
@@ -14,10 +12,22 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
         os.makedirs(save_dir, exist_ok=True)
     predict_result = [json.loads(x)['predictions'].strip() for x in open(predict_result)]
 
-    for sample, prediction in zip(data, predict_result):
-        sample['predictions'] = {'utterance': prediction}
+    merged = []
+    i = 0
+    for dataset_name in dataset_names.split('+'):
+        print(f'merging predictions for {dataset_name}')
+        dataset = load_dataset(dataset_name, args.dial_ids_order)
+        data = load_nlg_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+
+        for sample in data:
+            # samples with empty dialogue acts were skipped in create_data.py, so skip them here to stay aligned with predict_result
+            if all([len(sample['dialogue_acts'][da_type])==0 for da_type in sample['dialogue_acts']]):
+                continue
+            sample['predictions'] = {'utterance': predict_result[i]}
+            i += 1
+            merged.append(sample)
 
-    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
 
 
 if __name__ == '__main__':
diff --git a/convlab/base_models/t5/nlg/nlg_metric.py b/convlab/base_models/t5/nlg/nlg_metric.py
index 0c0155ff..9a59500d 100644
--- a/convlab/base_models/t5/nlg/nlg_metric.py
+++ b/convlab/base_models/t5/nlg/nlg_metric.py
@@ -73,6 +73,8 @@ class NLGMetrics(datasets.Metric):
 
     def _compute(self, predictions, references):
         """Returns the scores: bleu"""
+        references = [" " if ref=="" else ref for ref in references]
         bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score
         
         return {
diff --git a/convlab/base_models/t5/nlg/run_nlg.sh b/convlab/base_models/t5/nlg/run_nlg.sh
index 9de7fece..c45079a6 100644
--- a/convlab/base_models/t5/nlg/run_nlg.sh
+++ b/convlab/base_models/t5/nlg/run_nlg.sh
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
diff --git a/convlab/base_models/t5/nlg/run_nlg_fewshot.sh b/convlab/base_models/t5/nlg/run_nlg_fewshot.sh
index 6f7c8d17..4e00fb9d 100644
--- a/convlab/base_models/t5/nlg/run_nlg_fewshot.sh
+++ b/convlab/base_models/t5/nlg/run_nlg_fewshot.sh
@@ -42,7 +42,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --load_best_model_at_end \
     --cache_dir ${cache_dir} \
diff --git a/convlab/base_models/t5/nlg/run_nlg_multitask.sh b/convlab/base_models/t5/nlg/run_nlg_multitask.sh
new file mode 100644
index 00000000..9b0a3d47
--- /dev/null
+++ b/convlab/base_models/t5/nlg/run_nlg_multitask.sh
@@ -0,0 +1,99 @@
+n_gpus=1
+task_name="nlg"
+dataset_name="sgd+tm1+tm2+tm3+multiwoz21"
+speaker="system"
+context_window_size=0
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlg_metric.py"
+metric_for_best_model="bleu"
+source_column="context+da"
+target_column="response"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=64
+per_device_eval_batch_size=64
+gradient_accumulation_steps=8
+lr=1e-3
+num_train_epochs=10
+
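+# preprocess each dataset separately, then merge the results below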
+names=$(echo ${dataset_name} | tr "+" "\n")
+rm -rf ${data_dir}  # -f avoids an error when the directory does not exist yet
+mkdir -p ${data_dir}
+for name in ${names};
+do
+    echo "preprocessing ${name}"
+    python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
+done
+
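+# merge the per-dataset data into ${data_dir}, prefixing each source sequence with its dataset name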
+python merge_data.py $(echo ${dataset_name} | tr "+" " ")
+
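+# jointly fine-tune t5-small on the merged training data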
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
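+# generate predictions for the merged test set with the fine-tuned model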
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
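+# map the generated predictions back to the original test samples of each dataset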
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+# python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name}
diff --git a/convlab/base_models/t5/nlg/run_nlg_pretrain.sh b/convlab/base_models/t5/nlg/run_nlg_pretrain.sh
index 8af5dd10..a1a1b601 100644
--- a/convlab/base_models/t5/nlg/run_nlg_pretrain.sh
+++ b/convlab/base_models/t5/nlg/run_nlg_pretrain.sh
@@ -31,13 +31,11 @@ for name in ${names};
 do
     echo "preprocessing ${name}"
     python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
-    if [ "${name}" != "${dataset_name}" ]; then
-        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/train.json" >> ${train_file}
-        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/validation.json" >> ${validation_file}
-        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/test.json" >> ${test_file}
-    fi
 done
 
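+# merge the per-dataset splits with merge_data.py instead of concatenating the raw files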
+python merge_data.py $(echo ${dataset_name} | tr "+" " ")
+
 python ../run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
diff --git a/convlab/nlg/evaluate_unified_datasets.py b/convlab/nlg/evaluate_unified_datasets.py
index b8806dc6..7a19a492 100644
--- a/convlab/nlg/evaluate_unified_datasets.py
+++ b/convlab/nlg/evaluate_unified_datasets.py
@@ -36,6 +36,8 @@ def evaluate(predict_result, ontology):
         references.append(predict_result[i]['utterance'])
         candidates.append(predict_result[i]['predictions']['utterance'])
     # metrics['bleu'] = corpus_bleu(references, candidates)
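+    # sacrebleu expects non-empty reference strings, so substitute a single space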
+    references = [" " if ref=="" else ref for ref in references]
     metrics['bleu'] = sacrebleu.corpus_bleu(candidates, [references], lowercase=True).score
 
     # ERROR Rate
-- 
GitLab