Skip to content
Snippets Groups Projects
Commit 6daadd33 authored by zqwerty's avatar zqwerty
Browse files

fix import bug in dst/merge_predict_res.py

parent 9cf97f07
Branches
No related tags found
No related merge requests found
...@@ -75,8 +75,8 @@ class DSTMetrics(datasets.Metric): ...@@ -75,8 +75,8 @@ class DSTMetrics(datasets.Metric):
pred_state = deserialize_dialogue_state(prediction) pred_state = deserialize_dialogue_state(prediction)
gold_state = deserialize_dialogue_state(reference) gold_state = deserialize_dialogue_state(reference)
predicts = sorted(list({(domain, slot, value) for domain in pred_state for slot, value in pred_state[domain].items() if len(value)>0})) predicts = sorted(list({(domain, slot, ''.join(value.split()).lower()) for domain in pred_state for slot, value in pred_state[domain].items() if len(value)>0}))
labels = sorted(list({(domain, slot, value) for domain in gold_state for slot, value in gold_state[domain].items() if len(value)>0})) labels = sorted(list({(domain, slot, ''.join(value.split()).lower()) for domain in gold_state for slot, value in gold_state[domain].items() if len(value)>0}))
flag = True flag = True
for ele in predicts: for ele in predicts:
......
import json import json
import os import os
from convlab2.util import load_dataset, load_dst_data from convlab2.util import load_dataset, load_dst_data
from convlab2.base_models.t5.dst.serialization import deserialize_state from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state
def merge(dataset_name, speaker, save_dir, context_window_size, predict_result): def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
...@@ -13,7 +13,7 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result): ...@@ -13,7 +13,7 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
save_dir = os.path.dirname(predict_result) save_dir = os.path.dirname(predict_result)
else: else:
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
predict_result = [deserialize_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)] predict_result = [deserialize_dialogue_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
for sample, prediction in zip(data, predict_result): for sample, prediction in zip(data, predict_result):
sample['predictions'] = {'state': prediction} sample['predictions'] = {'state': prediction}
......
...@@ -26,7 +26,7 @@ num_train_epochs=10 ...@@ -26,7 +26,7 @@ num_train_epochs=10
python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --speaker ${speaker} --context_window_size ${context_window_size} python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --speaker ${speaker} --context_window_size ${context_window_size}
python -m torch.distributed.launch --master_port 29501 \ python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \ --nproc_per_node ${n_gpus} ../run_seq2seq.py \
--task_name ${task_name} \ --task_name ${task_name} \
--train_file ${train_file} \ --train_file ${train_file} \
...@@ -43,8 +43,7 @@ python -m torch.distributed.launch --master_port 29501 \ ...@@ -43,8 +43,7 @@ python -m torch.distributed.launch --master_port 29501 \
--do_predict \ --do_predict \
--save_strategy epoch \ --save_strategy epoch \
--evaluation_strategy epoch \ --evaluation_strategy epoch \
--load_best_model_at_end \ --prediction_loss_only \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \ --metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \ --cache_dir ${cache_dir} \
--output_dir ${output_dir} \ --output_dir ${output_dir} \
...@@ -60,24 +59,24 @@ python -m torch.distributed.launch --master_port 29501 \ ...@@ -60,24 +59,24 @@ python -m torch.distributed.launch --master_port 29501 \
--adafactor \ --adafactor \
--gradient_checkpointing --gradient_checkpointing
# python -m torch.distributed.launch \ python -m torch.distributed.launch \
# --nproc_per_node ${n_gpus} ../run_seq2seq.py \ --nproc_per_node ${n_gpus} ../run_seq2seq.py \
# --task_name ${task_name} \ --task_name ${task_name} \
# --test_file ${test_file} \ --test_file ${test_file} \
# --source_column ${source_column} \ --source_column ${source_column} \
# --target_column ${target_column} \ --target_column ${target_column} \
# --max_source_length ${max_source_length} \ --max_source_length ${max_source_length} \
# --max_target_length ${max_target_length} \ --max_target_length ${max_target_length} \
# --truncation_side ${truncation_side} \ --truncation_side ${truncation_side} \
# --model_name_or_path ${output_dir} \ --model_name_or_path ${output_dir} \
# --do_predict \ --do_predict \
# --predict_with_generate \ --predict_with_generate \
# --metric_name_or_path ${metric_name_or_path} \ --metric_name_or_path ${metric_name_or_path} \
# --cache_dir ${cache_dir} \ --cache_dir ${cache_dir} \
# --output_dir ${output_dir} \ --output_dir ${output_dir} \
# --logging_dir ${logging_dir} \ --logging_dir ${logging_dir} \
# --overwrite_output_dir \ --overwrite_output_dir \
# --preprocessing_num_workers 4 \ --preprocessing_num_workers 4 \
# --per_device_eval_batch_size ${per_device_eval_batch_size} \ --per_device_eval_batch_size ${per_device_eval_batch_size} \
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
...@@ -16,24 +16,24 @@ truncation_side="left" ...@@ -16,24 +16,24 @@ truncation_side="left"
max_source_length=512 max_source_length=512
max_target_length=128 max_target_length=128
model_name_or_path="t5-small" model_name_or_path="t5-small"
per_device_train_batch_size=32 per_device_train_batch_size=128
per_device_eval_batch_size=128 per_device_eval_batch_size=128
gradient_accumulation_steps=4 gradient_accumulation_steps=4
lr=1e-3 lr=1e-3
num_train_epochs=5 num_train_epochs=5
# names=$(echo ${dataset_name} | tr "+" "\n") names=$(echo ${dataset_name} | tr "+" "\n")
# mkdir -p ${data_dir} mkdir -p ${data_dir}
# for name in ${names}; for name in ${names};
# do do
# echo "preprocessing ${name}" echo "preprocessing ${name}"
# python ../create_data.py --tasks ${task_name} --datasets ${name} --speaker ${speaker} python ../create_data.py --tasks ${task_name} --datasets ${name} --speaker ${speaker}
# if [ "${name}" != "${dataset_name}" ]; then if [ "${name}" != "${dataset_name}" ]; then
# cat "data/${task_name}/${name}/${speaker}/train.json" >> ${train_file} cat "data/${task_name}/${name}/${speaker}/train.json" >> ${train_file}
# cat "data/${task_name}/${name}/${speaker}/validation.json" >> ${validation_file} cat "data/${task_name}/${name}/${speaker}/validation.json" >> ${validation_file}
# cat "data/${task_name}/${name}/${speaker}/test.json" >> ${test_file} cat "data/${task_name}/${name}/${speaker}/test.json" >> ${test_file}
# fi fi
# done done
python -m torch.distributed.launch \ python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \ --nproc_per_node ${n_gpus} ../run_seq2seq.py \
...@@ -53,7 +53,7 @@ python -m torch.distributed.launch \ ...@@ -53,7 +53,7 @@ python -m torch.distributed.launch \
--save_strategy epoch \ --save_strategy epoch \
--evaluation_strategy epoch \ --evaluation_strategy epoch \
--load_best_model_at_end \ --load_best_model_at_end \
--predict_with_generate \ --prediction_loss_only \
--cache_dir ${cache_dir} \ --cache_dir ${cache_dir} \
--output_dir ${output_dir} \ --output_dir ${output_dir} \
--logging_dir ${logging_dir} \ --logging_dir ${logging_dir} \
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment