Skip to content
Snippets Groups Projects
Unverified Commit 57eacd6f authored by 罗崚骁's avatar 罗崚骁 Committed by GitHub
Browse files

update XLDST evaluation on CrossWOZ dataset (#153)

* ignore space in crosswoz evaluation

* update model evaluation

* eoln

* update error logging

* update eval

* update eval
parent 1363ecdc
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,7 @@ import importlib
from tqdm import tqdm
from convlab2.dst import DST
from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result
from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result, extract_gt
def evaluate(model_dir, subtask, test_data, gt):
......@@ -34,16 +34,28 @@ def evaluate(model_dir, subtask, test_data, gt):
dump_result(model_dir, 'model-result.json', result, errors, pred)
def eval_team(team):
for subtask in ['multiwoz', 'crosswoz']:
test_data = prepare_data(subtask, 'dstc9')
gt = extract_gt(test_data)
for i in range(1, 6):
model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}')
if not os.path.exists(model_dir):
continue
print(model_dir)
evaluate(model_dir, subtask, test_data, gt)
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
parser.add_argument('--teams', type=str, nargs='*')
args = parser.parse_args()
subtask = args.subtask
test_data = prepare_data(subtask, args.split)
gt = {
dialog_id: [state for _, _, state in turns]
for dialog_id, turns in test_data.items()
}
evaluate('example', subtask, test_data, gt)
if not args.teams:
for team in os.listdir('.'):
if not os.path.isdir(team):
continue
eval_team(team)
else:
for team in args.teams:
eval_team(team)
import os
import json
import os
import zipfile
from copy import deepcopy
from convlab2 import DATA_ROOT
......@@ -23,31 +22,41 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['text'] if i else None
user_utt = turns[i]['text']
state = {}
dialog_state = {}
for domain_name, domain in turns[i + 1]['metadata'].items():
if domain_name in ['警察机关', '医院', '公共汽车']:
continue
domain_state = {}
state = {}
for slots in domain.values():
for slot_name, value in slots.items():
domain_state[slot_name] = value
state[domain_name] = domain_state
dialog_data.append((sys_utt, user_utt, state))
state[slot_name] = value
dialog_state[domain_name] = state
dialog_data.append((sys_utt, user_utt, dialog_state))
data[dialog_id] = dialog_data
else:
for dialog_id, dialog in test_data.items():
dialog_data = []
turns = dialog['messages']
selected_results = {k: [] for k in turns[1]['sys_state'].keys()}
for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['content'] if i else None
user_utt = turns[i]['content']
state = {}
for domain_name, domain_state in turns[i + 1]['sys_state_init'].items():
selected_results = domain_state.pop('selectedResults')
if selected_results and 'name' in domain_state and not domain_state['name']:
domain_state['name'] = selected_results
state[domain_name] = domain_state
dialog_data.append((sys_utt, user_utt, state))
dialog_state = {}
for domain_name, state in turns[i + 1]['sys_state_init'].items():
state.pop('selectedResults')
sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
# if state has changed compared to previous sys state
state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
# clear the outdated previous selected results if state has been updated
if state_change:
selected_results[domain_name].clear()
if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
state['name'] = selected_results[domain_name][0]
dialog_state[domain_name] = state
if state_change and sys_selected_results:
selected_results[domain_name] = sys_selected_results
dialog_data.append((sys_utt, user_utt, dialog_state))
data[dialog_id] = dialog_data
return data
......@@ -63,12 +72,6 @@ def extract_gt(test_data):
# for unifying values with the same meaning to the same expression
def unify_value(value, subtask):
if isinstance(value, list):
ret = deepcopy(value)
for i, v in enumerate(ret):
ret[i] = unify_value(v, subtask)
return ret
value = value.lower()
value = {
'multiwoz': {
......@@ -79,6 +82,7 @@ def unify_value(value, subtask):
},
'crosswoz': {
'none': '',
'free admission': 'free',
}
}[subtask].get(value, value)
......@@ -94,7 +98,7 @@ def eval_states(gt, pred, subtask):
for k, v in kargs.items():
ret[k] = v
return ret, None
errors = []
errors = [['dialog id', 'turn id', 'domain name', 'slot name', 'ground truth', 'predict']]
joint_acc, joint_tot = 0, 0
slot_acc, slot_tot = 0, 0
......@@ -122,12 +126,12 @@ def eval_states(gt, pred, subtask):
pred_value = unify_value(pred_domain[slot_name], subtask)
slot_tot += 1
if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
if gt_value == pred_value:
slot_acc += 1
if gt_value:
tp += 1
else:
errors.append([gt_value, pred_value])
errors.append([dialog_id, turn_id, domain_name, slot_name, gt_value, pred_value])
turn_result = False
if gt_value:
fn += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment