Skip to content
Snippets Groups Projects
Unverified Commit 57eacd6f authored by 罗崚骁's avatar 罗崚骁 Committed by GitHub
Browse files

update XLDST evaluation on CrossWOZ dataset (#153)

* ignore space in crosswoz evaluation

* update model evaluation

* eoln

* update error logging

* update eval

* update eval
parent 1363ecdc
No related branches found
No related tags found
No related merge requests found
...@@ -9,7 +9,7 @@ import importlib ...@@ -9,7 +9,7 @@ import importlib
from tqdm import tqdm from tqdm import tqdm
from convlab2.dst import DST from convlab2.dst import DST
from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result, extract_gt
def evaluate(model_dir, subtask, test_data, gt): def evaluate(model_dir, subtask, test_data, gt):
...@@ -34,16 +34,28 @@ def evaluate(model_dir, subtask, test_data, gt): ...@@ -34,16 +34,28 @@ def evaluate(model_dir, subtask, test_data, gt):
dump_result(model_dir, 'model-result.json', result, errors, pred) dump_result(model_dir, 'model-result.json', result, errors, pred)
def eval_team(team):
for subtask in ['multiwoz', 'crosswoz']:
test_data = prepare_data(subtask, 'dstc9')
gt = extract_gt(test_data)
for i in range(1, 6):
model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}')
if not os.path.exists(model_dir):
continue
print(model_dir)
evaluate(model_dir, subtask, test_data, gt)
if __name__ == '__main__': if __name__ == '__main__':
from argparse import ArgumentParser from argparse import ArgumentParser
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) parser.add_argument('--teams', type=str, nargs='*')
parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
args = parser.parse_args() args = parser.parse_args()
subtask = args.subtask if not args.teams:
test_data = prepare_data(subtask, args.split) for team in os.listdir('.'):
gt = { if not os.path.isdir(team):
dialog_id: [state for _, _, state in turns] continue
for dialog_id, turns in test_data.items() eval_team(team)
} else:
evaluate('example', subtask, test_data, gt) for team in args.teams:
eval_team(team)
import os
import json import json
import os
import zipfile import zipfile
from copy import deepcopy
from convlab2 import DATA_ROOT from convlab2 import DATA_ROOT
...@@ -23,31 +22,41 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): ...@@ -23,31 +22,41 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
for i in range(0, len(turns), 2): for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['text'] if i else None sys_utt = turns[i - 1]['text'] if i else None
user_utt = turns[i]['text'] user_utt = turns[i]['text']
state = {} dialog_state = {}
for domain_name, domain in turns[i + 1]['metadata'].items(): for domain_name, domain in turns[i + 1]['metadata'].items():
if domain_name in ['警察机关', '医院', '公共汽车']: if domain_name in ['警察机关', '医院', '公共汽车']:
continue continue
domain_state = {} state = {}
for slots in domain.values(): for slots in domain.values():
for slot_name, value in slots.items(): for slot_name, value in slots.items():
domain_state[slot_name] = value state[slot_name] = value
state[domain_name] = domain_state dialog_state[domain_name] = state
dialog_data.append((sys_utt, user_utt, state)) dialog_data.append((sys_utt, user_utt, dialog_state))
data[dialog_id] = dialog_data data[dialog_id] = dialog_data
else: else:
for dialog_id, dialog in test_data.items(): for dialog_id, dialog in test_data.items():
dialog_data = [] dialog_data = []
turns = dialog['messages'] turns = dialog['messages']
selected_results = {k: [] for k in turns[1]['sys_state'].keys()}
for i in range(0, len(turns), 2): for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['content'] if i else None sys_utt = turns[i - 1]['content'] if i else None
user_utt = turns[i]['content'] user_utt = turns[i]['content']
state = {} dialog_state = {}
for domain_name, domain_state in turns[i + 1]['sys_state_init'].items(): for domain_name, state in turns[i + 1]['sys_state_init'].items():
selected_results = domain_state.pop('selectedResults') state.pop('selectedResults')
if selected_results and 'name' in domain_state and not domain_state['name']: sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
domain_state['name'] = selected_results # if state has changed compared to previous sys state
state[domain_name] = domain_state state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
dialog_data.append((sys_utt, user_utt, state)) # clear the outdated previous selected results if state has been updated
if state_change:
selected_results[domain_name].clear()
if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
state['name'] = selected_results[domain_name][0]
dialog_state[domain_name] = state
if state_change and sys_selected_results:
selected_results[domain_name] = sys_selected_results
dialog_data.append((sys_utt, user_utt, dialog_state))
data[dialog_id] = dialog_data data[dialog_id] = dialog_data
return data return data
...@@ -63,12 +72,6 @@ def extract_gt(test_data): ...@@ -63,12 +72,6 @@ def extract_gt(test_data):
# for unifying values with the same meaning to the same expression # for unifying values with the same meaning to the same expression
def unify_value(value, subtask): def unify_value(value, subtask):
if isinstance(value, list):
ret = deepcopy(value)
for i, v in enumerate(ret):
ret[i] = unify_value(v, subtask)
return ret
value = value.lower() value = value.lower()
value = { value = {
'multiwoz': { 'multiwoz': {
...@@ -79,6 +82,7 @@ def unify_value(value, subtask): ...@@ -79,6 +82,7 @@ def unify_value(value, subtask):
}, },
'crosswoz': { 'crosswoz': {
'none': '', 'none': '',
'free admission': 'free',
} }
}[subtask].get(value, value) }[subtask].get(value, value)
...@@ -94,7 +98,7 @@ def eval_states(gt, pred, subtask): ...@@ -94,7 +98,7 @@ def eval_states(gt, pred, subtask):
for k, v in kargs.items(): for k, v in kargs.items():
ret[k] = v ret[k] = v
return ret, None return ret, None
errors = [] errors = [['dialog id', 'turn id', 'domain name', 'slot name', 'ground truth', 'predict']]
joint_acc, joint_tot = 0, 0 joint_acc, joint_tot = 0, 0
slot_acc, slot_tot = 0, 0 slot_acc, slot_tot = 0, 0
...@@ -122,12 +126,12 @@ def eval_states(gt, pred, subtask): ...@@ -122,12 +126,12 @@ def eval_states(gt, pred, subtask):
pred_value = unify_value(pred_domain[slot_name], subtask) pred_value = unify_value(pred_domain[slot_name], subtask)
slot_tot += 1 slot_tot += 1
if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: if gt_value == pred_value:
slot_acc += 1 slot_acc += 1
if gt_value: if gt_value:
tp += 1 tp += 1
else: else:
errors.append([gt_value, pred_value]) errors.append([dialog_id, turn_id, domain_name, slot_name, gt_value, pred_value])
turn_result = False turn_result = False
if gt_value: if gt_value:
fn += 1 fn += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment