From 58824f13964542f30b84f955947f1d8e622ba341 Mon Sep 17 00:00:00 2001 From: function2 <function2@qq.com> Date: Thu, 15 Oct 2020 02:12:48 +0800 Subject: [PATCH] add progress bar --- convlab2/dst/dstc9/eval_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index bad4ee7..aa42cc8 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -6,6 +6,8 @@ import os import json import importlib +from tqdm import tqdm + from convlab2.dst import DST from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result @@ -18,9 +20,13 @@ def evaluate(model_dir, subtask, test_data, gt): # load weights, set eval() on default model = model_cls() pred = {} + bar = tqdm(total=sum(len(turns) for turns in test_data.values()), ncols=80, desc='evaluating') for dialog_id, turns in test_data.items(): model.init_session() - pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] + for sys_utt, user_utt, gt_turn in turns: + pred[dialog_id] = [model.update_turn(sys_utt, user_utt)] + bar.update() + bar.close() result = eval_states(gt, pred, subtask) print(json.dumps(result, indent=4)) dump_result(model_dir, 'model-result.json', result) -- GitLab