diff --git a/convlab2/base_models/t5/dst/run_multiwoz21.sh b/convlab2/base_models/t5/dst/run_multiwoz21.sh index e7573e951391209e2e58ea7c3031ee56cc11843b..6d42380be583927f5d1ef8c148bd576d4ac061f8 100644 --- a/convlab2/base_models/t5/dst/run_multiwoz21.sh +++ b/convlab2/base_models/t5/dst/run_multiwoz21.sh @@ -80,3 +80,5 @@ python -m torch.distributed.launch \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh b/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh index 4080a09249fbed8260ed8e0b0fab7961d9d9120a..a9e9d6c55b38bfa1b43b0a837e4eef9d60a7e233 100644 --- a/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh +++ b/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh @@ -64,3 +64,5 @@ python -m torch.distributed.launch \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh b/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh index a8cbc9bd6c30b91a16c1e278450161992d8dca9a..e90e71d459da3bd43eeccc82d7ad192d7f751996 100644 --- a/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh +++ b/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh @@ -64,3 +64,5 @@ python -m torch.distributed.launch \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json 
diff --git a/convlab2/base_models/t5/nlu/run_tm1_user.sh b/convlab2/base_models/t5/nlu/run_tm1_user.sh index 9faedd7f664fb40ba20883033aaa1d3817d66c1c..5372441aa284067f137e892454b8fed76e251e6e 100644 --- a/convlab2/base_models/t5/nlu/run_tm1_user.sh +++ b/convlab2/base_models/t5/nlu/run_tm1_user.sh @@ -64,3 +64,5 @@ python -m torch.distributed.launch \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh b/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh index bb6b55fe06c54bab7294a55d8abda30e959acf34..65482a1f517b7c1eb8607e0858d3ae576d1483b4 100644 --- a/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh +++ b/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh @@ -64,3 +64,5 @@ python -m torch.distributed.launch \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm2_user.sh b/convlab2/base_models/t5/nlu/run_tm2_user.sh index 728a8a94748c8344104fb9176fd8d2599580b11d..84dc8b71ae560dcb481dee8bcfc31340ed4a778d 100644 --- a/convlab2/base_models/t5/nlu/run_tm2_user.sh +++ b/convlab2/base_models/t5/nlu/run_tm2_user.sh @@ -64,3 +64,5 @@ python -m torch.distributed.launch \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh 
b/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh index 8ebb102dd99c22a9e6dc752c09b48b1538c77ad8..abca0a60a98c1a71295ae1ad77791dec6b482547 100644 --- a/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh +++ b/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh @@ -64,3 +64,5 @@ python -m torch.distributed.launch \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm3_user.sh b/convlab2/base_models/t5/nlu/run_tm3_user.sh index 0d775f7ae41a63f72fc93539186b63aa2b4a551f..689a626c9a871581b49eb84a7db1e9af1152e32e 100644 --- a/convlab2/base_models/t5/nlu/run_tm3_user.sh +++ b/convlab2/base_models/t5/nlu/run_tm3_user.sh @@ -64,3 +64,5 @@ python -m torch.distributed.launch \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh b/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh index c8a4a9f6b5e77ee6e05ae36aea2e002202243c72..e2ded66dbe940387f1997ba92028ef4dc4a5b5c5 100644 --- a/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh +++ b/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh @@ -64,3 +64,5 @@ python -m torch.distributed.launch \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json + +python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab2/dst/evaluate_unified_datasets.py 
# Collapsed patch header, kept verbatim: b/convlab2/dst/evaluate_unified_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..78d0db4adf2ff1a84ae2b9175b5d8744628fa57b --- /dev/null +++ b/convlab2/dst/evaluate_unified_datasets.py @@ -0,0 +1,50 @@
import json
from pprint import pprint


def _state_triples(state):
    """Return the set of (domain, slot, value) triples for non-empty values.

    Values are normalized by stripping all whitespace and lower-casing, so
    that e.g. "Nor th" and "north" compare equal.
    """
    return {
        (domain, slot, ''.join(value.split()).lower())
        for domain, slots in state.items()
        for slot, value in slots.items()
        if len(value) > 0
    }


def evaluate(predict_result):
    """Compute slot-level P/R/F1 and exact-match accuracy for DST predictions.

    Args:
        predict_result: Path to a JSON file holding a list of samples. Each
            sample is read as ``sample['state']`` (gold) and
            ``sample['predictions']['state']`` (predicted), both mapping
            domain -> slot -> value string.

    Returns:
        A dict with the keys ``slot_f1``, ``slot_precision``, ``slot_recall``
        and ``accuracy``. ``accuracy`` is the fraction of samples whose
        predicted triple set equals the gold triple set. Every metric falls
        back to ``0.`` when its denominator is zero (including an empty
        prediction file).
    """
    # Close the file deterministically instead of leaking the handle.
    with open(predict_result, encoding='utf-8') as f:
        samples = json.load(f)

    tp = fp = fn = 0
    exact_matches = 0

    for sample in samples:
        predicted = _state_triples(sample['predictions']['state'])
        gold = _state_triples(sample['state'])

        # Set algebra yields the same counts as per-element membership tests
        # on the de-duplicated triples, without the quadratic list scans.
        tp += len(predicted & gold)
        fp += len(predicted - gold)
        fn += len(gold - predicted)
        exact_matches += predicted == gold

    precision = 1.0 * tp / (tp + fp) if tp + fp else 0.
    recall = 1.0 * tp / (tp + fn) if tp + fn else 0.
    f1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
    # Guard the empty-file case the same way the other ratios are guarded.
    accuracy = exact_matches / len(samples) if samples else 0.

    return {
        'slot_f1': f1,
        'slot_precision': precision,
        'slot_recall': recall,
        'accuracy': accuracy,
    }


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="calculate DST metrics for unified datasets")
    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the prediction file that in the unified data format')
    args = parser.parse_args()
    print(args)
    metrics = evaluate(args.predict_result)
    pprint(metrics)