Skip to content
Snippets Groups Projects
Commit 6daadd33 authored by zqwerty's avatar zqwerty
Browse files

fix import bug in dst/merge_predict_res.py

parent 9cf97f07
Branches
No related tags found
No related merge requests found
......@@ -75,8 +75,8 @@ class DSTMetrics(datasets.Metric):
pred_state = deserialize_dialogue_state(prediction)
gold_state = deserialize_dialogue_state(reference)
predicts = sorted(list({(domain, slot, value) for domain in pred_state for slot, value in pred_state[domain].items() if len(value)>0}))
labels = sorted(list({(domain, slot, value) for domain in gold_state for slot, value in gold_state[domain].items() if len(value)>0}))
predicts = sorted(list({(domain, slot, ''.join(value.split()).lower()) for domain in pred_state for slot, value in pred_state[domain].items() if len(value)>0}))
labels = sorted(list({(domain, slot, ''.join(value.split()).lower()) for domain in gold_state for slot, value in gold_state[domain].items() if len(value)>0}))
flag = True
for ele in predicts:
......
import json
import os
from convlab2.util import load_dataset, load_dst_data
from convlab2.base_models.t5.dst.serialization import deserialize_state
from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state
def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
......@@ -13,7 +13,7 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
save_dir = os.path.dirname(predict_result)
else:
os.makedirs(save_dir, exist_ok=True)
predict_result = [deserialize_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
predict_result = [deserialize_dialogue_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
for sample, prediction in zip(data, predict_result):
sample['predictions'] = {'state': prediction}
......
......@@ -26,7 +26,7 @@ num_train_epochs=10
python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --speaker ${speaker} --context_window_size ${context_window_size}
python -m torch.distributed.launch --master_port 29501 \
python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
......@@ -43,8 +43,7 @@ python -m torch.distributed.launch --master_port 29501 \
--do_predict \
--save_strategy epoch \
--evaluation_strategy epoch \
--load_best_model_at_end \
--predict_with_generate \
--prediction_loss_only \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
......@@ -60,24 +59,24 @@ python -m torch.distributed.launch --master_port 29501 \
--adafactor \
--gradient_checkpointing
# python -m torch.distributed.launch \
# --nproc_per_node ${n_gpus} ../run_seq2seq.py \
# --task_name ${task_name} \
# --test_file ${test_file} \
# --source_column ${source_column} \
# --target_column ${target_column} \
# --max_source_length ${max_source_length} \
# --max_target_length ${max_target_length} \
# --truncation_side ${truncation_side} \
# --model_name_or_path ${output_dir} \
# --do_predict \
# --predict_with_generate \
# --metric_name_or_path ${metric_name_or_path} \
# --cache_dir ${cache_dir} \
# --output_dir ${output_dir} \
# --logging_dir ${logging_dir} \
# --overwrite_output_dir \
# --preprocessing_num_workers 4 \
# --per_device_eval_batch_size ${per_device_eval_batch_size} \
python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
......@@ -16,24 +16,24 @@ truncation_side="left"
max_source_length=512
max_target_length=128
model_name_or_path="t5-small"
per_device_train_batch_size=32
per_device_train_batch_size=128
per_device_eval_batch_size=128
gradient_accumulation_steps=4
lr=1e-3
num_train_epochs=5
# names=$(echo ${dataset_name} | tr "+" "\n")
# mkdir -p ${data_dir}
# for name in ${names};
# do
# echo "preprocessing ${name}"
# python ../create_data.py --tasks ${task_name} --datasets ${name} --speaker ${speaker}
# if [ "${name}" != "${dataset_name}" ]; then
# cat "data/${task_name}/${name}/${speaker}/train.json" >> ${train_file}
# cat "data/${task_name}/${name}/${speaker}/validation.json" >> ${validation_file}
# cat "data/${task_name}/${name}/${speaker}/test.json" >> ${test_file}
# fi
# done
names=$(echo ${dataset_name} | tr "+" "\n")
mkdir -p ${data_dir}
for name in ${names};
do
echo "preprocessing ${name}"
python ../create_data.py --tasks ${task_name} --datasets ${name} --speaker ${speaker}
if [ "${name}" != "${dataset_name}" ]; then
cat "data/${task_name}/${name}/${speaker}/train.json" >> ${train_file}
cat "data/${task_name}/${name}/${speaker}/validation.json" >> ${validation_file}
cat "data/${task_name}/${name}/${speaker}/test.json" >> ${test_file}
fi
done
python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \
......@@ -53,7 +53,7 @@ python -m torch.distributed.launch \
--save_strategy epoch \
--evaluation_strategy epoch \
--load_best_model_at_end \
--predict_with_generate \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment