From 6daadd33e8c37dd1d12c3c895a44a76d4048a9f5 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Sun, 10 Apr 2022 13:15:37 +0800
Subject: [PATCH] fix import bug in dst/merge_predict_res.py

---
 convlab2/base_models/t5/dst/dst_metric.py     |  4 +-
 .../base_models/t5/dst/merge_predict_res.py   |  4 +-
 convlab2/base_models/t5/dst/run_multiwoz21.sh | 43 +++++++++----------
 convlab2/base_models/t5/rg/run_rg.sh          | 28 ++++++------
 4 files changed, 39 insertions(+), 40 deletions(-)

diff --git a/convlab2/base_models/t5/dst/dst_metric.py b/convlab2/base_models/t5/dst/dst_metric.py
index 8a4f73b0..aedef34d 100644
--- a/convlab2/base_models/t5/dst/dst_metric.py
+++ b/convlab2/base_models/t5/dst/dst_metric.py
@@ -75,8 +75,8 @@ class DSTMetrics(datasets.Metric):
             pred_state = deserialize_dialogue_state(prediction)
             gold_state = deserialize_dialogue_state(reference)
 
-            predicts = sorted(list({(domain, slot, value) for domain in pred_state for slot, value in pred_state[domain].items() if len(value)>0}))
-            labels = sorted(list({(domain, slot, value) for domain in gold_state for slot, value in gold_state[domain].items() if len(value)>0}))
+            predicts = sorted(list({(domain, slot, ''.join(value.split()).lower()) for domain in pred_state for slot, value in pred_state[domain].items() if len(value)>0}))
+            labels = sorted(list({(domain, slot, ''.join(value.split()).lower()) for domain in gold_state for slot, value in gold_state[domain].items() if len(value)>0}))
 
             flag = True
             for ele in predicts:
diff --git a/convlab2/base_models/t5/dst/merge_predict_res.py b/convlab2/base_models/t5/dst/merge_predict_res.py
index ebdada8a..0a80ee80 100755
--- a/convlab2/base_models/t5/dst/merge_predict_res.py
+++ b/convlab2/base_models/t5/dst/merge_predict_res.py
@@ -1,7 +1,7 @@
 import json
 import os
 from convlab2.util import load_dataset, load_dst_data
-from convlab2.base_models.t5.dst.serialization import deserialize_state
+from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state
 
 
 def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
@@ -13,7 +13,7 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
         save_dir = os.path.dirname(predict_result)
     else:
         os.makedirs(save_dir, exist_ok=True)
-    predict_result = [deserialize_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
+    predict_result = [deserialize_dialogue_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
 
     for sample, prediction in zip(data, predict_result):
         sample['predictions'] = {'state': prediction}
diff --git a/convlab2/base_models/t5/dst/run_multiwoz21.sh b/convlab2/base_models/t5/dst/run_multiwoz21.sh
index e031be48..e7573e95 100644
--- a/convlab2/base_models/t5/dst/run_multiwoz21.sh
+++ b/convlab2/base_models/t5/dst/run_multiwoz21.sh
@@ -26,7 +26,7 @@ num_train_epochs=10
 
 python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --speaker ${speaker} --context_window_size ${context_window_size}
 
-python -m torch.distributed.launch --master_port 29501 \
+python -m torch.distributed.launch \
     --nproc_per_node ${n_gpus} ../run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
@@ -43,8 +43,7 @@ python -m torch.distributed.launch --master_port 29501 \
     --do_predict \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --load_best_model_at_end \
-    --predict_with_generate \
+    --prediction_loss_only \
    --metric_name_or_path ${metric_name_or_path} \
    --cache_dir ${cache_dir} \
    --output_dir ${output_dir} \
@@ -60,24 +59,24 @@ python -m torch.distributed.launch --master_port 29501 \
     --adafactor \
     --gradient_checkpointing
 
-# python -m torch.distributed.launch \
-#     --nproc_per_node ${n_gpus} ../run_seq2seq.py \
-#     --task_name ${task_name} \
-#     --test_file ${test_file} \
-#     --source_column ${source_column} \
-#     --target_column ${target_column} \
-#     --max_source_length ${max_source_length} \
-#     --max_target_length ${max_target_length} \
-#     --truncation_side ${truncation_side} \
-#     --model_name_or_path ${output_dir} \
-#     --do_predict \
-#     --predict_with_generate \
-#     --metric_name_or_path ${metric_name_or_path} \
-#     --cache_dir ${cache_dir} \
-#     --output_dir ${output_dir} \
-#     --logging_dir ${logging_dir} \
-#     --overwrite_output_dir \
-#     --preprocessing_num_workers 4 \
-#     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+python -m torch.distributed.launch \
+    --nproc_per_node ${n_gpus} ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
 
 python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
diff --git a/convlab2/base_models/t5/rg/run_rg.sh b/convlab2/base_models/t5/rg/run_rg.sh
index 55accadf..6fcffca2 100644
--- a/convlab2/base_models/t5/rg/run_rg.sh
+++ b/convlab2/base_models/t5/rg/run_rg.sh
@@ -16,24 +16,24 @@ truncation_side="left"
 max_source_length=512
 max_target_length=128
 model_name_or_path="t5-small"
-per_device_train_batch_size=32
+per_device_train_batch_size=128
 per_device_eval_batch_size=128
 gradient_accumulation_steps=4
 lr=1e-3
 num_train_epochs=5
 
-# names=$(echo ${dataset_name} | tr "+" "\n")
-# mkdir -p ${data_dir}
-# for name in ${names};
-# do
-#     echo "preprocessing ${name}"
-#     python ../create_data.py --tasks ${task_name} --datasets ${name} --speaker ${speaker}
-#     if [ "${name}" != "${dataset_name}" ]; then
-#         cat "data/${task_name}/${name}/${speaker}/train.json" >> ${train_file}
-#         cat "data/${task_name}/${name}/${speaker}/validation.json" >> ${validation_file}
-#         cat "data/${task_name}/${name}/${speaker}/test.json" >> ${test_file}
-#     fi
-# done
+names=$(echo ${dataset_name} | tr "+" "\n")
+mkdir -p ${data_dir}
+for name in ${names};
+do
+    echo "preprocessing ${name}"
+    python ../create_data.py --tasks ${task_name} --datasets ${name} --speaker ${speaker}
+    if [ "${name}" != "${dataset_name}" ]; then
+        cat "data/${task_name}/${name}/${speaker}/train.json" >> ${train_file}
+        cat "data/${task_name}/${name}/${speaker}/validation.json" >> ${validation_file}
+        cat "data/${task_name}/${name}/${speaker}/test.json" >> ${test_file}
+    fi
+done
 
 python -m torch.distributed.launch \
     --nproc_per_node ${n_gpus} ../run_seq2seq.py \
@@ -53,7 +53,7 @@ python -m torch.distributed.launch \
     --save_strategy epoch \
     --evaluation_strategy epoch \
     --load_best_model_at_end \
-    --predict_with_generate \
+    --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
     --logging_dir ${logging_dir} \
-- 
GitLab
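
A minimal sketch (not part of the patch) of the effect of the dst_metric.py change: slot values are whitespace-stripped and lowercased before the predicted and gold state triples are compared, so surface variants of the same value no longer count as errors. The helper name normalize_triples and the example states below are illustrative only.

    def normalize_triples(state):
        # Flatten a {domain: {slot: value}} dict into sorted (domain, slot, value)
        # triples, normalizing each value the way the patched metric does.
        return sorted({
            (domain, slot, ''.join(value.split()).lower())
            for domain in state
            for slot, value in state[domain].items()
            if len(value) > 0
        })

    pred_state = {'hotel': {'type': 'Guest House'}}
    gold_state = {'hotel': {'type': 'guest house'}}

    # Both normalize to [('hotel', 'type', 'guesthouse')], so this turn now
    # counts as a joint-accuracy match.
    assert normalize_triples(pred_state) == normalize_triples(gold_state)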
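
And a hedged sketch of how the corrected import is exercised when merging predictions. The JSON-lines layout of generated_predictions.json (one object per line with a "predictions" field) is taken from the patched merge() above; the file path here is illustrative.

    import json
    from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state

    # Each line holds one JSON object whose "predictions" field is a serialized
    # dialogue state; deserialize_dialogue_state (the corrected name, replacing
    # the stale deserialize_state) turns it back into a {domain: {slot: value}} dict.
    with open('output/multiwoz21/generated_predictions.json') as f:
        states = [deserialize_dialogue_state(json.loads(line)['predictions'].strip())
                  for line in f]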