diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py
index bad4ee79a4ff284efe5a5613b00d59938658a8ce..aa42cc84ea82bc6635bd7348406cb1b42d2ca799 100644
--- a/convlab2/dst/dstc9/eval_model.py
+++ b/convlab2/dst/dstc9/eval_model.py
@@ -6,6 +6,8 @@
 import os
 import json
 import importlib
 
+from tqdm import tqdm
+
 from convlab2.dst import DST
 from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result
@@ -18,9 +20,17 @@ def evaluate(model_dir, subtask, test_data, gt):
     # load weights, set eval() on default
     model = model_cls()
     pred = {}
+    # progress bar sized by the total number of turns across all dialogs
+    bar = tqdm(total=sum(len(turns) for turns in test_data.values()), ncols=80, desc='evaluating')
     for dialog_id, turns in test_data.items():
         model.init_session()
-        pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
+        # Accumulate one prediction per turn. NOTE: reassigning a fresh
+        # single-element list inside the loop would keep only the last turn.
+        pred[dialog_id] = []
+        for sys_utt, user_utt, gt_turn in turns:
+            pred[dialog_id].append(model.update_turn(sys_utt, user_utt))
+            bar.update()
+    bar.close()
     result = eval_states(gt, pred, subtask)
     print(json.dumps(result, indent=4))
     dump_result(model_dir, 'model-result.json', result)