Skip to content
Snippets Groups Projects
Commit e0ee065c authored by function2's avatar function2
Browse files

add correct name labem argument

parent eb9d4011
No related branches found
No related tags found
No related merge requests found
...@@ -51,10 +51,11 @@ if __name__ == '__main__': ...@@ -51,10 +51,11 @@ if __name__ == '__main__':
from argparse import ArgumentParser from argparse import ArgumentParser
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val']) parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val', 'dstc9-250'])
parser.add_argument('correct_name_label', action='store_true')
args = parser.parse_args() args = parser.parse_args()
subtask = args.subtask subtask = args.subtask
split = args.split split = args.split
dump_example(subtask, split) dump_example(subtask, split)
test_data = prepare_data(subtask, split) test_data = prepare_data(subtask, split, correct_name_label=args.correct_name_label)
gt = extract_gt(test_data) gt = extract_gt(test_data)
...@@ -34,9 +34,9 @@ def evaluate(model_dir, subtask, test_data, gt): ...@@ -34,9 +34,9 @@ def evaluate(model_dir, subtask, test_data, gt):
dump_result(model_dir, 'model-result.json', result, errors, pred) dump_result(model_dir, 'model-result.json', result, errors, pred)
def eval_team(team): def eval_team(team, correct_name_label):
for subtask in ['multiwoz', 'crosswoz']: for subtask in ['multiwoz', 'crosswoz']:
test_data = prepare_data(subtask, 'dstc9') test_data = prepare_data(subtask, 'dstc9', correct_name_label=correct_name_label)
gt = extract_gt(test_data) gt = extract_gt(test_data)
for i in range(1, 6): for i in range(1, 6):
model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}') model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}')
...@@ -50,12 +50,13 @@ if __name__ == '__main__': ...@@ -50,12 +50,13 @@ if __name__ == '__main__':
from argparse import ArgumentParser from argparse import ArgumentParser
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('--teams', type=str, nargs='*') parser.add_argument('--teams', type=str, nargs='*')
parser.add_argument('correct_name_label', action='store_true')
args = parser.parse_args() args = parser.parse_args()
if not args.teams: if not args.teams:
for team in os.listdir('.'): for team in os.listdir('.'):
if not os.path.isdir(team): if not os.path.isdir(team):
continue continue
eval_team(team) eval_team(team, args.correct_name_label)
else: else:
for team in args.teams: for team in args.teams:
eval_team(team) eval_team(team, args.correct_name_label)
...@@ -11,7 +11,7 @@ def get_subdir(subtask): ...@@ -11,7 +11,7 @@ def get_subdir(subtask):
return subdir return subdir
def prepare_data(subtask, split, data_root=DATA_ROOT): def prepare_data(subtask, split, data_root=DATA_ROOT, correct_name_label=False):
data_dir = os.path.join(data_root, get_subdir(subtask)) data_dir = os.path.join(data_root, get_subdir(subtask))
zip_filename = os.path.join(data_dir, f'{split}.json.zip') zip_filename = os.path.join(data_dir, f'{split}.json.zip')
test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json')) test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json'))
...@@ -38,11 +38,27 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): ...@@ -38,11 +38,27 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
for dialog_id, dialog in test_data.items(): for dialog_id, dialog in test_data.items():
dialog_data = [] dialog_data = []
turns = dialog['messages'] turns = dialog['messages']
if correct_name_label:
selected_results = {domain_name: [] for domain_name in turns[1]['sys_state_init']}
for i in range(0, len(turns), 2): for i in range(0, len(turns), 2):
sys_utt = turns[i - 1]['content'] if i else None sys_utt = turns[i - 1]['content'] if i else None
user_utt = turns[i]['content'] user_utt = turns[i]['content']
dialog_state = {} dialog_state = {}
for domain_name, state in turns[i + 1]['sys_state_init'].items(): for domain_name, state in turns[i + 1]['sys_state_init'].items():
if correct_name_label:
state.pop('selectedResults')
sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
# if state has changed compared to previous sys state
state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
# clear the outdated previous selected results if state has been updated
if state_change:
selected_results[domain_name].clear()
if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
state['name'] = selected_results[domain_name][0]
dialog_state[domain_name] = state
if state_change and sys_selected_results:
selected_results[domain_name] = sys_selected_results
else:
selected_results = state.pop('selectedResults') selected_results = state.pop('selectedResults')
if selected_results and 'name' in state and not state['name']: if selected_results and 'name' in state and not state['name']:
state['name'] = selected_results state['name'] = selected_results
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment