Skip to content
Snippets Groups Projects
Unverified Commit 42fb447c authored by 罗崚骁's avatar 罗崚骁 Committed by GitHub
Browse files

Merge pull request #148 from function2-llx/master

XLDST evaluation update
parents 4f9d5759 2c8dc0d9
No related branches found
No related tags found
No related merge requests found
......@@ -5,28 +5,28 @@
import json
import os
from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir
from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir, dump_result
def evaluate(model_dir, subtask, gt):
subdir = get_subdir(subtask)
results = {}
for i in range(1, 6):
filepath = os.path.join(model_dir, subdir, f'submission{i}.json')
filepath = os.path.join(model_dir, f'submission{i}.json')
if not os.path.exists(filepath):
continue
pred = json.load(open(filepath))
results[i] = eval_states(gt, pred, subtask)
print(json.dumps(results, indent=4))
json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False)
result, errors = eval_states(gt, pred, subtask)
print(json.dumps(result, indent=4, ensure_ascii=False))
dump_result(model_dir, 'file-result.json', result)
return
raise ValueError('submission file not found')
# generate submission examples
def dump_example(subtask, split):
test_data = prepare_data(subtask, split)
pred = extract_gt(test_data)
json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4)
subdir = get_subdir(subtask)
json.dump(pred, open(os.path.join('example', subdir, 'submission1.json'), 'w'), ensure_ascii=False, indent=4)
import random
for dialog_id, states in pred.items():
for state in states:
......@@ -38,13 +38,13 @@ def dump_example(subtask, split):
else:
if random.randint(0, 4) == 0:
domain[slot] = "2333"
json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4)
json.dump(pred, open(os.path.join('example', subdir, 'submission2.json'), 'w'), ensure_ascii=False, indent=4)
for dialog_id, states in pred.items():
for state in states:
for domain in state.values():
for slot in domain:
domain[slot] = ""
json.dump(pred, open(os.path.join('example', get_subdir(subtask), 'submission3.json'), 'w'), ensure_ascii=False, indent=4)
json.dump(pred, open(os.path.join('example', subdir, 'submission3.json'), 'w'), ensure_ascii=False, indent=4)
if __name__ == '__main__':
......@@ -58,4 +58,3 @@ if __name__ == '__main__':
dump_example(subtask, split)
test_data = prepare_data(subtask, split)
gt = extract_gt(test_data)
evaluate('example', subtask, gt)
......@@ -6,25 +6,32 @@ import os
import json
import importlib
from tqdm import tqdm
from convlab2.dst import DST
from convlab2.dst.dstc9.utils import prepare_data, eval_states, get_subdir
from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result
def evaluate(model_dir, subtask, test_data, gt):
subdir = get_subdir(subtask)
module = importlib.import_module(f'{model_dir}.{subdir}')
module = importlib.import_module(model_dir.replace('/', '.'))
assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root'
model_cls = module.__getattribute__('Model')
model_cls = getattr(module, 'Model')
assert issubclass(model_cls, DST), 'the model must implement DST interface'
# load weights, set eval() on default
model = model_cls()
pred = {}
bar = tqdm(total=sum(len(turns) for turns in test_data.values()), ncols=80, desc='evaluating')
for dialog_id, turns in test_data.items():
model.init_session()
pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
result = eval_states(gt, pred, subtask)
pred[dialog_id] = []
for sys_utt, user_utt, gt_turn in turns:
pred[dialog_id].append(model.update_turn(sys_utt, user_utt))
bar.update()
bar.close()
result, errors = eval_states(gt, pred, subtask)
print(json.dumps(result, indent=4))
json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False)
dump_result(model_dir, 'model-result.json', result, errors, pred)
if __name__ == '__main__':
......
......@@ -25,7 +25,7 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
user_utt = turns[i]['text']
state = {}
for domain_name, domain in turns[i + 1]['metadata'].items():
if domain_name in ['警察机关', '医院']:
if domain_name in ['警察机关', '医院', '公共汽车']:
continue
domain_state = {}
for slots in domain.values():
......@@ -79,7 +79,7 @@ def unify_value(value, subtask):
'crosswoz': {
'None': '',
}
}[subtask].get(value, value)
}[subtask].get(value, value).lower()
return ' '.join(value.strip().split())
......@@ -92,7 +92,8 @@ def eval_states(gt, pred, subtask):
}
for k, v in kargs.items():
ret[k] = v
return ret
return ret, None
errors = []
joint_acc, joint_tot = 0, 0
slot_acc, slot_tot = 0, 0
......@@ -119,11 +120,13 @@ def eval_states(gt, pred, subtask):
gt_value = unify_value(gt_value, subtask)
pred_value = unify_value(pred_domain[slot_name], subtask)
slot_tot += 1
if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
slot_acc += 1
if gt_value:
tp += 1
else:
errors.append([gt_value, pred_value])
turn_result = False
if gt_value:
fn += 1
......@@ -143,4 +146,17 @@ def eval_states(gt, pred, subtask):
'recall': recall,
'f1': f1,
}
}
}, errors
def dump_result(model_dir, filename, result, errors=None, pred=None):
    """Write evaluation artifacts under ``../results/<model_dir>``.

    Args:
        model_dir: Directory name of the evaluated model; joined onto
            ``../results`` to form the output directory (created if absent).
        filename: Name of the JSON file that receives ``result``.
        result: JSON-serializable evaluation summary.
        errors: Optional list of rows (e.g. ``[gt_value, pred_value]`` pairs)
            written to ``errors.csv`` when non-empty.
        pred: Optional JSON-serializable predictions written to ``pred.json``
            when non-empty.
    """
    # Local import keeps csv out of the module namespace; it is only
    # needed on the error-dumping path.
    import csv

    output_dir = os.path.join('../results', model_dir)
    os.makedirs(output_dir, exist_ok=True)
    # Use context managers so handles are closed and output is flushed
    # even if serialization raises.
    with open(os.path.join(output_dir, filename), 'w') as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    if errors:
        # newline='' is required by the csv module to avoid extra blank
        # lines on platforms that translate line endings.
        with open(os.path.join(output_dir, 'errors.csv'), 'w', newline='') as f:
            csv.writer(f).writerows(errors)
    if pred:
        with open(os.path.join(output_dir, 'pred.json'), 'w') as f:
            json.dump(pred, f, indent=4, ensure_ascii=False)
*/dstc9*.json.zip
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment