Skip to content
Snippets Groups Projects
Commit aac882a3 authored by aaa123git's avatar aaa123git
Browse files

call init_session at the first turn of a new session

parent 27f2a2b3
Branches
No related tags found
No related merge requests found
...@@ -8,6 +8,7 @@ import json ...@@ -8,6 +8,7 @@ import json
import os import os
import random import random
import sys import sys
import itertools
import zipfile import zipfile
import numpy import numpy
from numpy.lib.shape_base import _put_along_axis_dispatcher from numpy.lib.shape_base import _put_along_axis_dispatcher
...@@ -211,16 +212,18 @@ if __name__ == '__main__': ...@@ -211,16 +212,18 @@ if __name__ == '__main__':
numpy.random.seed(seed) numpy.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if len(sys.argv) != 4: if len(sys.argv) < 4:
print("usage:") print("usage:")
print("\t python evaluate.py dataset model role") print("\t python evaluate.py dataset model role")
print("\t dataset=MultiWOZ, CrossWOZ, or Camrest") print("\t dataset=MultiWOZ, CrossWOZ, or Camrest")
print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG") print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG")
print("\t role=usr/sys") print("\t role=usr/sys")
print("\t [Optional] model_file")
sys.exit() sys.exit()
dataset_name = sys.argv[1] dataset_name = sys.argv[1]
model_name = sys.argv[2] model_name = sys.argv[2]
role = sys.argv[3] role = sys.argv[3]
model_file = sys.argv[4] if len(sys.argv) >= 5 else None
if dataset_name == 'MultiWOZ': if dataset_name == 'MultiWOZ':
if model_name == 'SCLSTM': if model_name == 'SCLSTM':
from convlab2.nlg.sclstm.multiwoz import SCLSTM from convlab2.nlg.sclstm.multiwoz import SCLSTM
...@@ -242,17 +245,19 @@ if __name__ == '__main__': ...@@ -242,17 +245,19 @@ if __name__ == '__main__':
model = TemplateNLG(is_user=False) model = TemplateNLG(is_user=False)
elif model_name == 'SCGPT': elif model_name == 'SCGPT':
from convlab2.nlg.scgpt.multiwoz import SCGPT from convlab2.nlg.scgpt.multiwoz import SCGPT
if model_file is not None:
print(f"load model at {model_file}")
if role == 'usr': if role == 'usr':
model = SCGPT(is_user=True) model = SCGPT(model_file, is_user=True)
elif role == 'sys': elif role == 'sys':
model = SCGPT(is_user=False, model_file='scgpt/trained_output/multiwoz/') model = SCGPT(model_file, is_user=False)
else: else:
raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE") raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE")
from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader
from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader()) dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader())
data = dataloader.load_data(data_key='all', role=role)['test'] data = dataloader.load_data(data_key='all', role=role, session_id=True)['test']
dialog_acts = [] dialog_acts = []
golden_utts = [] golden_utts = []
...@@ -262,7 +267,19 @@ if __name__ == '__main__': ...@@ -262,7 +267,19 @@ if __name__ == '__main__':
sen_num = 0 sen_num = 0
# sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w') # sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w')
assert 'utterance' in data and 'dialog_act' in data and 'session_id' in data
assert len(data['utterance']) == len(data['dialog_act']) == len(data['session_id'])
# Turns during the same session should be contiguous, so we can call init_session at the first turn of a new session.
# This is necessary for SCGPT, but unnecessary for SCLSTM and TemplateNLG.
is_first_turn = []
for _, iterator in itertools.groupby(data['session_id']):
is_first_turn.append(True)
next(iterator)
is_first_turn.extend(False for _ in iterator)
for i in tqdm(range(len(data['utterance']))): for i in tqdm(range(len(data['utterance']))):
if is_first_turn[i]:
model.init_session()
dialog_acts.append(data['dialog_act'][i]) dialog_acts.append(data['dialog_act'][i])
golden_utts.append(data['utterance'][i]) golden_utts.append(data['utterance'][i])
gen_utts.append(model.generate(data['dialog_act'][i])) gen_utts.append(model.generate(data['dialog_act'][i]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment