diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index aa42cc84ea82bc6635bd7348406cb1b42d2ca799..207eac70e6f28e4b9bcab0e48cf284afba9a8e4a 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -23,8 +23,9 @@ def evaluate(model_dir, subtask, test_data, gt): bar = tqdm(total=sum(len(turns) for turns in test_data.values()), ncols=80, desc='evaluating') for dialog_id, turns in test_data.items(): model.init_session() + pred[dialog_id] = [] for sys_utt, user_utt, gt_turn in turns: - pred[dialog_id] = [model.update_turn(sys_utt, user_utt)] + pred[dialog_id].append(model.update_turn(sys_utt, user_utt)) bar.update() bar.close() result = eval_states(gt, pred, subtask)