diff --git a/convlab2/nlg/evaluate.py b/convlab2/nlg/evaluate.py index e2cffc4553060c9c6c5e812c0ccd317981274fd6..1a4747b7f19a47f2c069e4ba286c9ad16763043b 100755 --- a/convlab2/nlg/evaluate.py +++ b/convlab2/nlg/evaluate.py @@ -8,6 +8,7 @@ import json import os import random import sys +import itertools import zipfile import numpy from numpy.lib.shape_base import _put_along_axis_dispatcher @@ -211,16 +212,18 @@ if __name__ == '__main__': numpy.random.seed(seed) torch.manual_seed(seed) - if len(sys.argv) != 4: + if len(sys.argv) < 4: print("usage:") print("\t python evaluate.py dataset model role") print("\t dataset=MultiWOZ, CrossWOZ, or Camrest") print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG") print("\t role=usr/sys") + print("\t [Optional] model_file") sys.exit() dataset_name = sys.argv[1] model_name = sys.argv[2] role = sys.argv[3] + model_file = sys.argv[4] if len(sys.argv) >= 5 else None if dataset_name == 'MultiWOZ': if model_name == 'SCLSTM': from convlab2.nlg.sclstm.multiwoz import SCLSTM @@ -242,17 +245,19 @@ if __name__ == '__main__': model = TemplateNLG(is_user=False) elif model_name == 'SCGPT': from convlab2.nlg.scgpt.multiwoz import SCGPT + if model_file is not None: + print(f"load model at {model_file}") if role == 'usr': - model = SCGPT(is_user=True) + model = SCGPT(model_file, is_user=True) elif role == 'sys': - model = SCGPT(is_user=False, model_file='scgpt/trained_output/multiwoz/') + model = SCGPT(model_file, is_user=False) else: raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE") from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader()) - data = dataloader.load_data(data_key='all', role=role)['test'] + data = dataloader.load_data(data_key='all', role=role, session_id=True)['test'] dialog_acts = [] golden_utts = [] @@ -262,7 +267,19 @@ if __name__ == '__main__': sen_num = 0 # sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w') + assert 'utterance' in data and 'dialog_act' in data and 'session_id' in data + assert len(data['utterance']) == len(data['dialog_act']) == len(data['session_id']) + + # Turns during the same session should be contiguous, so we can call init_session at the first turn of a new session. + # This is necessary for SCGPT, but unnecessary for SCLSTM and TemplateNLG. + is_first_turn = [] + for _, iterator in itertools.groupby(data['session_id']): + is_first_turn.append(True) + next(iterator) + is_first_turn.extend(False for _ in iterator) for i in tqdm(range(len(data['utterance']))): + if is_first_turn[i]: + model.init_session() dialog_acts.append(data['dialog_act'][i]) golden_utts.append(data['utterance'][i]) gen_utts.append(model.generate(data['dialog_act'][i]))