Skip to content
Snippets Groups Projects
Commit b770f3f0 authored by zqwerty's avatar zqwerty
Browse files

add t5dst interface, fix bug in t5nlu,t5nlg interface that ignore context_window_size

parent 5f2f6a44
No related branches found
No related tags found
No related merge requests found
import logging
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from convlab2.dst.dst import DST
from convlab2.base_models.t5.dst.serialization import deserialize_dialogue_state
from convlab2.util.custom_util import model_downloader
class T5DST(DST):
def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'):
assert speaker in ['user', 'system']
assert context_window_size > 0
self.speaker = speaker
self.opponent = 'system' if speaker == 'user' else 'user'
self.context_window_size = context_window_size
model_dir = os.path.dirname(os.path.abspath(__file__))
if not os.path.exists(model_name_or_path):
model_downloader(model_dir, model_file)
self.config = AutoConfig.from_pretrained(model_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config)
self.model.eval()
self.device = device if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
logging.info("T5DST loaded")
def update(self, context):
if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
context = [item[1] for item in context]
context = context[-self.context_window_size:]
input_seq = '\n'.join([f"{self.opponent if (i % 2) == (len(context) % 2) else self.speaker}: {utt}" for i, utt in enumerate(context)])
# print(input_seq)
input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device)
# print(input_seq)
output_seq = self.model.generate(**input_seq, max_length=256)
# print(output_seq)
output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
# print(output_seq)
state = deserialize_dialogue_state(output_seq.strip())
return state
if __name__ == '__main__':
contexts = [
["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton."],
["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
"What time do you want to leave and what time do you want to arrive by?",
"I want to leave after 17:15."],
["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
"What time do you want to leave and what time do you want to arrive by?",
"I want to leave after 17:15.",
"Booking completed! your taxi will be blue honda Contact number is 07218068540",
"Thank you for all the help! I appreciate it."],
["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
"What time do you want to leave and what time do you want to arrive by?",
"I want to leave after 17:15.",
"Booking completed! your taxi will be blue honda Contact number is 07218068540",
"Thank you for all the help! I appreciate it.",
"You are welcome. Is there anything else I can help you with today?",
"No, I am all set. Have a nice day. Bye."],
]
dst = T5DST(speaker='user', context_window_size=100, model_name_or_path='output/dst/multiwoz21/user/context_100')
for context in contexts:
print(dst.update(context))
print()
...@@ -32,13 +32,14 @@ class T5NLG(NLG): ...@@ -32,13 +32,14 @@ class T5NLG(NLG):
if self.use_context: if self.use_context:
if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1: if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
context = [item[1] for item in context] context = [item[1] for item in context]
context = context[-self.context_window_size:]
utts = context + [''] utts = context + ['']
else: else:
utts = [''] utts = ['']
input_seq = '\n'.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)]) input_seq = '\n'.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
dialogue_acts_seq = serialize_dialogue_acts(dialogue_acts) dialogue_acts_seq = serialize_dialogue_acts(dialogue_acts)
input_seq = dialogue_acts_seq + '\n' + input_seq input_seq = dialogue_acts_seq + '\n' + input_seq
print(input_seq) # print(input_seq)
input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device) input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device)
# print(input_seq) # print(input_seq)
output_seq = self.model.generate(**input_seq, max_length=256) output_seq = self.model.generate(**input_seq, max_length=256)
...@@ -122,10 +123,16 @@ if __name__ == '__main__': ...@@ -122,10 +123,16 @@ if __name__ == '__main__':
["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.", ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
"What time do you want to leave and what time do you want to arrive by?", "What time do you want to leave and what time do you want to arrive by?",
"I want to leave after 17:15."], "I want to leave after 17:15."],
["I want to leave after 17:15.", ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
"What time do you want to leave and what time do you want to arrive by?",
"I want to leave after 17:15.",
"Booking completed! your taxi will be blue honda Contact number is 07218068540", "Booking completed! your taxi will be blue honda Contact number is 07218068540",
"Thank you for all the help! I appreciate it."], "Thank you for all the help! I appreciate it."],
["Thank you for all the help! I appreciate it.", ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
"What time do you want to leave and what time do you want to arrive by?",
"I want to leave after 17:15.",
"Booking completed! your taxi will be blue honda Contact number is 07218068540",
"Thank you for all the help! I appreciate it.",
"You are welcome. Is there anything else I can help you with today?" "You are welcome. Is there anything else I can help you with today?"
"No, I am all set. Have a nice day. Bye."], "No, I am all set. Have a nice day. Bye."],
] ]
......
...@@ -32,6 +32,7 @@ class T5NLU(NLU): ...@@ -32,6 +32,7 @@ class T5NLU(NLU):
if self.use_context: if self.use_context:
if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1: if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
context = [item[1] for item in context] context = [item[1] for item in context]
context = context[-self.context_window_size:]
utts = context + [utterance] utts = context + [utterance]
else: else:
utts = [utterance] utts = [utterance]
...@@ -60,13 +61,15 @@ if __name__ == '__main__': ...@@ -60,13 +61,15 @@ if __name__ == '__main__':
[], [],
["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.", ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
"What time do you want to leave and what time do you want to arrive by?"], "What time do you want to leave and what time do you want to arrive by?"],
["What time do you want to leave and what time do you want to arrive by?", ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
"What time do you want to leave and what time do you want to arrive by?",
"I want to leave after 17:15.", "I want to leave after 17:15.",
"Booking completed! your taxi will be blue honda Contact number is 07218068540"], "Booking completed! your taxi will be blue honda Contact number is 07218068540"],
[], [],
["Please find a restaurant called Nusha.", ["Please find a restaurant called Nusha.",
"I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?"], "I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?"],
["I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?", ["Please find a restaurant called Nusha.",
"I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?",
"I am not sure of the type of food but could you please check again and see if you can find it? Thank you.", "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.",
"Could you double check that you've spelled the name correctly? The closest I can find is Nandos."] "Could you double check that you've spelled the name correctly? The closest I can find is Nandos."]
] ]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment