Skip to content
Snippets Groups Projects
Commit 2d5229ef authored by function2's avatar function2
Browse files

update precision, recall, f1 calculation

parent 54c57c59
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import os import os
import json import json
from copy import deepcopy
from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir
...@@ -16,21 +17,35 @@ def evaluate(model_dir, subtask, gt): ...@@ -16,21 +17,35 @@ def evaluate(model_dir, subtask, gt):
if not os.path.exists(filepath): if not os.path.exists(filepath):
continue continue
pred = json.load(open(filepath)) pred = json.load(open(filepath))
results[i] = eval_states(gt, pred) results[i] = eval_states(gt, pred, subtask)
print(json.dumps(results, indent=4))
json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False) json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False)
# generate submission examples
def dump_example(subtask, split): def dump_example(subtask, split):
test_data = prepare_data(subtask, split) test_data = prepare_data(subtask, split)
gt = extract_gt(test_data) pred = extract_gt(test_data)
json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4) json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4)
for dialog_id, states in gt.items(): import random
for dialog_id, states in pred.items():
for state in states:
for domain in state.values():
for slot, value in domain.items():
if value:
if random.randint(0, 2) == 0:
domain[slot] = ""
else:
if random.randint(0, 4) == 0:
domain[slot] = "2333"
json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4)
for dialog_id, states in pred.items():
for state in states: for state in states:
for domain in state.values(): for domain in state.values():
for slot in domain: for slot in domain:
domain[slot] = "" domain[slot] = ""
json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4) json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission3.json'), 'w'), ensure_ascii=False, indent=4)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -22,8 +22,8 @@ def evaluate(model_dir, subtask, test_data, gt): ...@@ -22,8 +22,8 @@ def evaluate(model_dir, subtask, test_data, gt):
for dialog_id, turns in test_data.items(): for dialog_id, turns in test_data.items():
model.init_session() model.init_session()
pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
result = eval_states(gt, pred) result = eval_states(gt, pred, subtask)
print(result) print(json.dumps(result, indent=4))
json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False) json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False)
......
...@@ -57,7 +57,20 @@ def extract_gt(test_data): ...@@ -57,7 +57,20 @@ def extract_gt(test_data):
return gt return gt
def eval_states(gt, pred): def eval_states(gt, pred, subtask):
# for unifying values with the same meaning to the same expression
value_unifier = {
'multiwoz': {
},
'crosswoz': {
'未提及': '',
}
}[subtask]
def unify_value(value):
return value_unifier.get(value, value)
def exception(description, **kargs): def exception(description, **kargs):
ret = { ret = {
'status': 'exception', 'status': 'exception',
...@@ -89,32 +102,34 @@ def eval_states(gt, pred): ...@@ -89,32 +102,34 @@ def eval_states(gt, pred):
for slot_name, gt_value in gt_domain.items(): for slot_name, gt_value in gt_domain.items():
if slot_name not in pred_domain: if slot_name not in pred_domain:
return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name) return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name)
pred_value = pred_domain[slot_name] gt_value = unify_value(gt_value)
pred_value = unify_value(pred_domain[slot_name])
slot_tot += 1 slot_tot += 1
if gt_value == pred_value: if gt_value == pred_value:
slot_acc += 1 slot_acc += 1
if gt_value:
tp += 1 tp += 1
else: else:
turn_result = False turn_result = False
# for class of gt_value if gt_value:
fn += 1 fn += 1
# for class of pred_value if pred_value:
fp += 1 fp += 1
joint_acc += turn_result joint_acc += turn_result
precision = tp / (tp + fp) precision = tp / (tp + fp) if tp + fp else 1
recall = tp / (tp + fn) recall = tp / (tp + fn) if tp + fn else 1
f1 = 2 * tp / (2 * tp + fp + fn) f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 1
return { return {
'status': 'ok', 'status': 'ok',
'joint accuracy': joint_acc / joint_tot, 'joint accuracy': joint_acc / joint_tot,
'slot accuracy': slot_acc / slot_tot, # 'slot accuracy': slot_acc / slot_tot,
# 'slot': { 'slot': {
# 'accuracy': slot_acc / slot_tot, 'accuracy': slot_acc / slot_tot,
# 'precision': precision, 'precision': precision,
# 'recall': recall, 'recall': recall,
# 'f1': f1, 'f1': f1,
# } }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment