diff --git a/convlab2/base_models/t5/nlg/merge_predict_res.py b/convlab2/base_models/t5/nlg/merge_predict_res.py new file mode 100755 index 0000000000000000000000000000000000000000..91e6055e13522caa5763c19b0d443b5244cc2496 --- /dev/null +++ b/convlab2/base_models/t5/nlg/merge_predict_res.py @@ -0,0 +1,33 @@ +import json +import os +from convlab2.util import load_dataset, load_nlg_data + + +def merge(dataset_name, speaker, save_dir, context_window_size, predict_result): + assert os.path.exists(predict_result) + dataset = load_dataset(dataset_name) + data = load_nlg_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test'] + + if save_dir is None: + save_dir = os.path.dirname(predict_result) + else: + os.makedirs(save_dir, exist_ok=True) + predict_result = [json.loads(x)['predictions'].strip() for x in open(predict_result)] + + for sample, prediction in zip(data, predict_result): + sample['predictions'] = {'utterance': prediction} + + json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + + +if __name__ == '__main__': + from argparse import ArgumentParser + parser = ArgumentParser(description="merge predict results with original data for unified NLU evaluation") + parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset') + parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances') + parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result') + parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered') + parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json') + args = parser.parse_args() + print(args) + merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result) diff --git a/convlab2/base_models/t5/nlg/run_nlg.sh b/convlab2/base_models/t5/nlg/run_nlg.sh new file mode 100644 index 0000000000000000000000000000000000000000..c9dc80842f38ed462e4b711d675f1445acefe1ca --- /dev/null +++ b/convlab2/base_models/t5/nlg/run_nlg.sh @@ -0,0 +1,79 @@ +n_gpus=1 +task_name="nlg" +dataset_name=$1 +speaker="system" +context_window_size=$2 +data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" +output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" +cache_dir="../cache" +logging_dir="${output_dir}/runs" +train_file="${data_dir}/train.json" +validation_file="${data_dir}/validation.json" +test_file="${data_dir}/test.json" +metric_name_or_path="nlg_metric.py" +metric_for_best_model="bleu" +source_column="context+da" +target_column="response" +truncation_side="right" +max_source_length=512 +max_target_length=512 +model_name_or_path="t5-small" +per_device_train_batch_size=128 +per_device_eval_batch_size=64 +gradient_accumulation_steps=4 +lr=1e-3 +num_train_epochs=10 + +python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} + +python ../run_seq2seq.py \ + --task_name ${task_name} \ + --train_file ${train_file} \ + --validation_file ${validation_file} \ + --source_column ${source_column} \ + --target_column ${target_column} \ + --max_source_length ${max_source_length} \ + --max_target_length ${max_target_length} \ + --truncation_side ${truncation_side} \ + --model_name_or_path ${model_name_or_path} \ + --do_train \ + --do_eval \ + --save_strategy epoch \ + --evaluation_strategy epoch \ + --prediction_loss_only \ + --cache_dir ${cache_dir} \ + --output_dir ${output_dir} \ + --logging_dir ${logging_dir} \ + --overwrite_output_dir \ + --preprocessing_num_workers 4 \ + --per_device_train_batch_size ${per_device_train_batch_size} \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --learning_rate ${lr} \ + --num_train_epochs ${num_train_epochs} \ + --debug underflow_overflow \ + --adafactor \ + --gradient_checkpointing + +python ../run_seq2seq.py \ + --task_name ${task_name} \ + --test_file ${test_file} \ + --source_column ${source_column} \ + --target_column ${target_column} \ + --max_source_length ${max_source_length} \ + --max_target_length ${max_target_length} \ + --truncation_side ${truncation_side} \ + --model_name_or_path ${output_dir} \ + --do_predict \ + --predict_with_generate \ + --metric_name_or_path ${metric_name_or_path} \ + --cache_dir ${cache_dir} \ + --output_dir ${output_dir} \ + --logging_dir ${logging_dir} \ + --overwrite_output_dir \ + --preprocessing_num_workers 4 \ + --per_device_eval_batch_size ${per_device_eval_batch_size} + +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab2/nlg/evaluate_unified_datasets.py b/convlab2/nlg/evaluate_unified_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..c2256308c95f66189a36fdc0216dc29d602563d8 --- /dev/null +++ b/convlab2/nlg/evaluate_unified_datasets.py @@ -0,0 +1,27 @@ +import json +from pprint import pprint +import sacrebleu + + +def evaluate(predict_result): + predict_result = json.load(open(predict_result)) + + metrics = {} + predictions, references = [], [] + for sample in predict_result: + references.append(sample['utterance']) + predictions.append(sample['predictions']['utterance']) + + metrics['bleu'] = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score + + return metrics + + +if __name__ == '__main__': + from argparse import ArgumentParser + parser = ArgumentParser(description="calculate NLU metrics for unified datasets") + parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the prediction file that in the unified data format') + args = parser.parse_args() + print(args) + metrics = evaluate(args.predict_result) + pprint(metrics)