Skip to content
Snippets Groups Projects
Unverified Commit d70c0fc6, authored by 罗崚骁, committed by GitHub
Browse files

add metrics in XLDST evaluation (#126)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change
parent e8ae8881
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import os import os
import json import json
from copy import deepcopy
from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir
...@@ -16,21 +17,35 @@ def evaluate(model_dir, subtask, gt): ...@@ -16,21 +17,35 @@ def evaluate(model_dir, subtask, gt):
if not os.path.exists(filepath): if not os.path.exists(filepath):
continue continue
pred = json.load(open(filepath)) pred = json.load(open(filepath))
results[i] = eval_states(gt, pred) results[i] = eval_states(gt, pred, subtask)
print(json.dumps(results, indent=4))
json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False) json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False)
# generate submission examples
def dump_example(subtask, split):
    """Write three example submission files for ``subtask``.

    The states returned by ``extract_gt(prepare_data(subtask, split))`` are
    dumped under ``example/<get_subdir(subtask)>/`` as:

    * ``submission1.json`` -- the extracted states, unchanged.
    * ``submission2.json`` -- a randomly perturbed copy: each non-empty slot
      value is blanked with probability 1/3 and each empty slot value is
      replaced by ``"2333"`` with probability 1/5.
    * ``submission3.json`` -- every slot value set to the empty string.

    Args:
        subtask: Subtask name, forwarded to ``prepare_data`` and
            ``get_subdir``.
        split: Data split name, forwarded to ``prepare_data``.
    """
    # Kept function-local, as in the original, so the module's import block
    # is untouched.
    import random

    def dump(obj, filename):
        # UTF-8 is requested explicitly because ensure_ascii=False emits raw
        # non-ASCII characters, which a non-UTF-8 locale default cannot
        # always encode. The context manager closes the handle promptly.
        with open(os.path.join(out_dir, filename), 'w', encoding='utf-8') as f:
            json.dump(obj, f, ensure_ascii=False, indent=4)

    test_data = prepare_data(subtask, split)
    pred = extract_gt(test_data)

    out_dir = os.path.join('example', get_subdir(subtask))
    os.makedirs(out_dir, exist_ok=True)

    dump(pred, 'submission1.json')

    # Perturb the states in place to produce an imperfect submission.
    # NOTE(review): the flattened source does not show which `if` the `else`
    # belongs to; it is paired with `if value:` here (empty values may become
    # "2333") -- confirm against the repository.
    for states in pred.values():
        for state in states:
            for domain in state.values():
                # Only existing keys are reassigned, so the dict size is
                # stable while iterating over items().
                for slot, value in domain.items():
                    if value:
                        if random.randint(0, 2) == 0:
                            domain[slot] = ""
                    elif random.randint(0, 4) == 0:
                        domain[slot] = "2333"
    dump(pred, 'submission2.json')

    # Blank every slot to produce an all-empty submission.
    for states in pred.values():
        for state in states:
            for domain in state.values():
                for slot in domain:
                    domain[slot] = ""
    dump(pred, 'submission3.json')
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -22,8 +22,8 @@ def evaluate(model_dir, subtask, test_data, gt): ...@@ -22,8 +22,8 @@ def evaluate(model_dir, subtask, test_data, gt):
for dialog_id, turns in test_data.items(): for dialog_id, turns in test_data.items():
model.init_session() model.init_session()
pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
result = eval_states(gt, pred) result = eval_states(gt, pred, subtask)
print(result) print(json.dumps(result, indent=4))
json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False) json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False)
......
...@@ -5,8 +5,13 @@ import zipfile ...@@ -5,8 +5,13 @@ import zipfile
from convlab2 import DATA_ROOT from convlab2 import DATA_ROOT
def get_subdir(subtask):
    """Return the data sub-directory name for ``subtask``.

    Args:
        subtask: Subtask name. ``'multiwoz'`` selects ``'multiwoz_zh'``;
            any other value selects ``'crosswoz_en'``.

    Returns:
        The sub-directory name as a string.
    """
    return 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en'
def prepare_data(subtask, split, data_root=DATA_ROOT): def prepare_data(subtask, split, data_root=DATA_ROOT):
data_dir = os.path.join(data_root, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en') data_dir = os.path.join(data_root, get_subdir(subtask))
zip_filename = os.path.join(data_dir, f'{split}.json.zip') zip_filename = os.path.join(data_dir, f'{split}.json.zip')
test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json')) test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json'))
data = {} data = {}
...@@ -57,7 +62,20 @@ def extract_gt(test_data): ...@@ -57,7 +62,20 @@ def extract_gt(test_data):
return gt return gt
def eval_states(gt, pred): def eval_states(gt, pred, subtask):
# for unifying values with the same meaning to the same expression
value_unifier = {
'multiwoz': {
},
'crosswoz': {
'未提及': '',
}
}[subtask]
def unify_value(value):
return value_unifier.get(value, value)
def exception(description, **kargs): def exception(description, **kargs):
ret = { ret = {
'status': 'exception', 'status': 'exception',
...@@ -89,35 +107,32 @@ def eval_states(gt, pred): ...@@ -89,35 +107,32 @@ def eval_states(gt, pred):
for slot_name, gt_value in gt_domain.items(): for slot_name, gt_value in gt_domain.items():
if slot_name not in pred_domain: if slot_name not in pred_domain:
return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name) return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name)
pred_value = pred_domain[slot_name] gt_value = unify_value(gt_value)
pred_value = unify_value(pred_domain[slot_name])
slot_tot += 1 slot_tot += 1
if gt_value == pred_value: if gt_value == pred_value:
slot_acc += 1 slot_acc += 1
if gt_value:
tp += 1 tp += 1
else: else:
turn_result = False turn_result = False
# for class of gt_value if gt_value:
fn += 1 fn += 1
# for class of pred_value if pred_value:
fp += 1 fp += 1
joint_acc += turn_result joint_acc += turn_result
precision = tp / (tp + fp) precision = tp / (tp + fp) if tp + fp else 1
recall = tp / (tp + fn) recall = tp / (tp + fn) if tp + fn else 1
f1 = 2 * tp / (2 * tp + fp + fn) f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 1
return { return {
'status': 'ok', 'status': 'ok',
'joint accuracy': joint_acc / joint_tot, 'joint accuracy': joint_acc / joint_tot,
'slot accuracy': slot_acc / slot_tot, # 'slot accuracy': slot_acc / slot_tot,
# 'slot': { 'slot': {
# 'accuracy': slot_acc / slot_tot, 'accuracy': slot_acc / slot_tot,
# 'precision': precision, 'precision': precision,
# 'recall': recall, 'recall': recall,
# 'f1': f1, 'f1': f1,
# } }
} }
def get_subdir(subtask):
    """Return the data sub-directory name for ``subtask``.

    Args:
        subtask: Subtask name. ``'multiwoz'`` selects ``'multiwoz_zh'``;
            any other value selects ``'crosswoz_en'``.

    Returns:
        The sub-directory name as a string.
    """
    return 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment