Skip to content
Snippets Groups Projects
Commit 58824f13 authored by function2's avatar function2
Browse files

add progress bar

parent fc453658
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,8 @@ import os ...@@ -6,6 +6,8 @@ import os
import json import json
import importlib import importlib
from tqdm import tqdm
from convlab2.dst import DST from convlab2.dst import DST
from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result
...@@ -18,9 +20,13 @@ def evaluate(model_dir, subtask, test_data, gt): ...@@ -18,9 +20,13 @@ def evaluate(model_dir, subtask, test_data, gt):
# load weights, set eval() on default # load weights, set eval() on default
model = model_cls() model = model_cls()
pred = {} pred = {}
bar = tqdm(total=sum(len(turns) for turns in test_data.values()), ncols=80, desc='evaluating')
for dialog_id, turns in test_data.items(): for dialog_id, turns in test_data.items():
model.init_session() model.init_session()
pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] for sys_utt, user_utt, gt_turn in turns:
pred[dialog_id] = [model.update_turn(sys_utt, user_utt)]
bar.update()
bar.close()
result = eval_states(gt, pred, subtask) result = eval_states(gt, pred, subtask)
print(json.dumps(result, indent=4)) print(json.dumps(result, indent=4))
dump_result(model_dir, 'model-result.json', result) dump_result(model_dir, 'model-result.json', result)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment