From 58824f13964542f30b84f955947f1d8e622ba341 Mon Sep 17 00:00:00 2001
From: function2 <function2@qq.com>
Date: Thu, 15 Oct 2020 02:12:48 +0800
Subject: [PATCH] add progress bar

---
 convlab2/dst/dstc9/eval_model.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py
index bad4ee7..aa42cc8 100644
--- a/convlab2/dst/dstc9/eval_model.py
+++ b/convlab2/dst/dstc9/eval_model.py
@@ -6,6 +6,8 @@ import os
 import json
 import importlib
 
+from tqdm import tqdm
+
 from convlab2.dst import DST
 from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result
 
@@ -18,9 +20,13 @@ def evaluate(model_dir, subtask, test_data, gt):
     # load weights, set eval() on default
     model = model_cls()
     pred = {}
+    bar = tqdm(total=sum(len(turns) for turns in test_data.values()), ncols=80, desc='evaluating')
     for dialog_id, turns in test_data.items():
         model.init_session()
-        pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns]
+        for sys_utt, user_utt, gt_turn in turns:
+            pred[dialog_id] = [model.update_turn(sys_utt, user_utt)]
+            bar.update()
+    bar.close()
     result = eval_states(gt, pred, subtask)
     print(json.dumps(result, indent=4))
     dump_result(model_dir, 'model-result.json', result)
-- 
GitLab