Unverified commit 6e16334a, authored by zhuqi and committed by GitHub

Merge pull request #49 from ConvLab/unified_dataset

add dst unified evaluation script
parents 2a5865c5 2a4e214b
@@ -80,3 +80,5 @@ python -m torch.distributed.launch \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
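Both merge_predict_res.py and the evaluation script operate on the unified data format. Judging from evaluate_unified_datasets.py below, the merged predictions.json is a JSON list in which each sample carries the gold state and the predicted predictions.state, each a domain -> slot -> value mapping. A minimal sketch of one entry (the domain, slot, and value names are illustrative, not from the repository):

[
  {
    "state": {"restaurant": {"area": "centre", "food": ""}},
    "predictions": {"state": {"restaurant": {"area": "centre"}}}
  }
]

Empty values are ignored by the evaluation, so the entry above counts as an exact joint-state match.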
@@ -64,3 +64,5 @@ python -m torch.distributed.launch \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
The new DST evaluation script, referenced above as ../../../dst/evaluate_unified_datasets.py:

import json
from pprint import pprint


def evaluate(predict_result):
    """Compute slot precision/recall/F1 and joint state accuracy for DST predictions."""
    with open(predict_result) as f:
        predict_result = json.load(f)

    metrics = {'TP': 0, 'FP': 0, 'FN': 0}
    acc = []
    for sample in predict_result:
        pred_state = sample['predictions']['state']
        gold_state = sample['state']
        # Flatten each state into (domain, slot, normalized value) triples,
        # skipping empty values; values are lower-cased with whitespace removed.
        predicts = sorted({(domain, slot, ''.join(value.split()).lower())
                           for domain in pred_state
                           for slot, value in pred_state[domain].items() if len(value) > 0})
        labels = sorted({(domain, slot, ''.join(value.split()).lower())
                         for domain in gold_state
                         for slot, value in gold_state[domain].items() if len(value) > 0})
        for ele in predicts:
            if ele in labels:
                metrics['TP'] += 1
            else:
                metrics['FP'] += 1
        for ele in labels:
            if ele not in predicts:
                metrics['FN'] += 1
        # Joint accuracy: the predicted state must match the gold state exactly.
        acc.append(predicts == labels)

    TP = metrics.pop('TP')
    FP = metrics.pop('FP')
    FN = metrics.pop('FN')
    precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
    recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
    f1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
    metrics['slot_f1'] = f1
    metrics['slot_precision'] = precision
    metrics['slot_recall'] = recall
    metrics['accuracy'] = sum(acc) / len(acc)
    return metrics


if __name__ == '__main__':
    from argparse import ArgumentParser

    parser = ArgumentParser(description="calculate DST metrics for unified datasets")
    parser.add_argument('--predict_result', '-p', type=str, required=True,
                        help='path to the prediction file in the unified data format')
    args = parser.parse_args()
    print(args)
    metrics = evaluate(args.predict_result)
    pprint(metrics)
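As a quick sanity check, here is a tiny, hypothetical input and the metrics the script computes for it (the file name and dialogue states are made up for illustration):

import json

# Hypothetical toy sample: the predicted 'stars' value is wrong, 'area' is right.
samples = [{
    'state': {'hotel': {'area': 'east', 'stars': '3'}},
    'predictions': {'state': {'hotel': {'area': 'east', 'stars': '4'}}},
}]
with open('toy_predictions.json', 'w') as f:
    json.dump(samples, f)

# python evaluate_unified_datasets.py -p toy_predictions.json
# 'area' is a true positive; the mismatched 'stars' yields one false positive
# and one false negative, so slot_precision = slot_recall = slot_f1 = 0.5,
# and since the joint state differs, accuracy = 0.0.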