Skip to content
Snippets Groups Projects
Unverified Commit 0d45aa65 authored by 罗崚骁's avatar 罗崚骁 Committed by GitHub
Browse files

fix XLDST evaluation (#141)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* add value unification
parent 3ee2c5c4
No related branches found
No related tags found
No related merge requests found
......@@ -2,9 +2,8 @@
evaluate output file
"""
import os
import json
from copy import deepcopy
import os
from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir
......
import os
import json
import zipfile
from copy import deepcopy
from convlab2 import DATA_ROOT
......@@ -41,12 +42,10 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
sys_utt = turns[i - 1]['content'] if i else None
user_utt = turns[i]['content']
state = {}
for domain_name, domain in turns[i + 1]['sys_state_init'].items():
domain_state = {}
for slot_name, value in domain.items():
if slot_name == 'selectedResults':
continue
domain_state[slot_name] = value
for domain_name, domain_state in turns[i + 1]['sys_state_init'].items():
selected_results = domain_state.pop('selectedResults')
if selected_results and 'name' in domain_state and not domain_state['name']:
domain_state['name'] = selected_results
state[domain_name] = domain_state
dialog_data.append((sys_utt, user_utt, state))
data[dialog_id] = dialog_data
......@@ -62,20 +61,28 @@ def extract_gt(test_data):
return gt
def eval_states(gt, pred, subtask):
# for unifying values with the same meaning to the same expression
value_unifier = {
'multiwoz': {
def unify_value(value, subtask):
if isinstance(value, list):
ret = deepcopy(value)
for i, v in enumerate(ret):
ret[i] = unify_value(v, subtask)
return ret
return {
'multiwoz': {
'未提及': '',
'none': '',
'是的': '',
'不是': '没有',
},
'crosswoz': {
'未提及': '',
'None': '',
}
}[subtask]
}[subtask].get(value, value)
def unify_value(value):
return value_unifier.get(value, value)
def eval_states(gt, pred, subtask):
def exception(description, **kargs):
ret = {
'status': 'exception',
......@@ -107,10 +114,10 @@ def eval_states(gt, pred, subtask):
for slot_name, gt_value in gt_domain.items():
if slot_name not in pred_domain:
return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name)
gt_value = unify_value(gt_value)
pred_value = unify_value(pred_domain[slot_name])
gt_value = unify_value(gt_value, subtask)
pred_value = unify_value(pred_domain[slot_name], subtask)
slot_tot += 1
if gt_value == pred_value:
if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
slot_acc += 1
if gt_value:
tp += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment