Commit 61b6c665 authored by zqwerty

t5dst multitask

parent 06ed57f4
merge_data.py:

import json
import os
import sys

if __name__ == '__main__':
    # one bucket per split, filled with samples from every dataset given on the command line
    merged_data = {'train': [], 'validation': [], 'test': []}
    print(sys.argv)
    for dataset_name in sys.argv[1:]:
        data_dir = os.path.join('data/dst', dataset_name, 'user/context_100')
        for data_split in merged_data:
            with open(os.path.join(data_dir, f'{data_split}.json'), 'r') as f:
                for line in f:
                    item = json.loads(line)
                    # prefix the context with the dataset name so samples from
                    # different datasets stay distinguishable after merging
                    item['context'] = f"{dataset_name}: {item['context']}"
                    merged_data[data_split].append(item)
    # write the merged splits to a directory named after the joined dataset names,
    # e.g. data/dst/sgd+tm1/user/context_100
    for data_split in merged_data:
        data_dir = os.path.join('data/dst', '+'.join(sys.argv[1:]), 'user/context_100')
        os.makedirs(data_dir, exist_ok=True)
        with open(os.path.join(data_dir, f'{data_split}.json'), 'w') as f:
            for item in merged_data[data_split]:
                f.write(json.dumps(item) + '\n')
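For illustration, here is what the script does to a single JSON-lines record. The field values below are made up; only the context and state_seq field names are taken from the source/target columns used by the training script further down:

import json

# hypothetical record from data/dst/multiwoz21/user/context_100/train.json
line = '{"context": "user: I need a cheap hotel.", "state_seq": "..."}'
item = json.loads(line)
item['context'] = f"multiwoz21: {item['context']}"
print(json.dumps(item))
# {"context": "multiwoz21: user: I need a cheap hotel.", "state_seq": "..."}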
merge_predict_res.py:

@@ -4,10 +4,8 @@ from convlab.util import load_dataset, load_dst_data
 from convlab.base_models.t5.dst.serialization import deserialize_dialogue_state

-def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+def merge(dataset_names, speaker, save_dir, context_window_size, predict_result):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name, args.dial_ids_order)
-    data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
     if save_dir is None:
         save_dir = os.path.dirname(predict_result)
@@ -15,10 +13,19 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
     os.makedirs(save_dir, exist_ok=True)
     predict_result = [deserialize_dialogue_state(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
-    for sample, prediction in zip(data, predict_result):
-        sample['predictions'] = {'state': prediction}
+    merged = []
+    i = 0
+    for dataset_name in dataset_names.split('+'):
+        print(dataset_name)
+        dataset = load_dataset(dataset_name, args.dial_ids_order)
+        data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+        for sample in data:
+            sample['predictions'] = {'state': predict_result[i]}
+            i += 1
+            merged.append(sample)
-    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)

 if __name__ == '__main__':
 ...
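The realignment relies purely on position: the predictions file is assumed to preserve the order of the merged test set, so a running index maps each deserialized state back to its sample as the per-dataset test sets are iterated in the order given by dataset_names. A minimal sketch of that alignment with toy data (all names and values here are invented):

# toy stand-ins for the deserialized predictions and per-dataset test sets
predictions = ['state_a', 'state_b', 'state_c']   # in merged test-set order
test_sets = {
    'sgd': [{'context': 'a'}, {'context': 'b'}],
    'tm1': [{'context': 'c'}],
}

merged, i = [], 0
for name in 'sgd+tm1'.split('+'):
    for sample in test_sets[name]:
        sample['predictions'] = {'state': predictions[i]}
        i += 1
        merged.append(sample)

assert i == len(predictions)  # counts and order must match exactly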
In a training shell script, the checkpoint budget drops from 3 to 1:

@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
...
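The commit also includes the full multitask pipeline script: preprocess each dataset, merge them with merge_data.py, train one t5-small model on the union, predict on the merged test set, and evaluate the merged predictions: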
n_gpus=1
task_name="dst"
dataset_name="sgd+tm1+tm2+tm3+multiwoz21"
speaker="user"
context_window_size=100
data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="dst_metric.py"
metric_for_best_model="accuracy"
source_column="context"
target_column="state_seq"
truncation_side="left"
max_source_length=1024
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=10
# preprocess each dataset separately, then merge everything into ${data_dir}
names=$(echo ${dataset_name} | tr "+" "\n")
rm -rf ${data_dir}
mkdir -p ${data_dir}
for name in ${names};
do
    echo "preprocessing ${name}"
    # per-dataset data must already exist; uncomment to (re)create it:
    # python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
done
python merge_data.py $(echo ${dataset_name} | tr "+" " ")
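
# fine-tune t5-small on the merged multitask data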
python ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 1 \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
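
# generate predictions on the merged test set with the fine-tuned model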
python ../run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
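
# attach predictions to their test samples, then evaluate with the unified-datasets DST metric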
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json