Skip to content
Snippets Groups Projects
Unverified Commit a9ef624e authored by zhuqi's avatar zhuqi Committed by GitHub
Browse files

Merge pull request #96 from ConvLab/readme

fix bug in t5dst, t5nlg, and scgpt interface
parents 7d55eaae bfb44786
No related branches found
No related tags found
No related merge requests found
import logging
import os
import torch
from copy import deepcopy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from convlab.dst.dst import DST
from convlab.base_models.t5.dst.serialization import deserialize_dialogue_state
from convlab.util.custom_util import model_downloader
from convlab.util import load_ontology
class T5DST(DST):
def __init__(self, speaker, context_window_size, model_name_or_path, device='cuda'):
def __init__(self, dataset_name, speaker, context_window_size, model_name_or_path, device='cuda'):
assert speaker in ['user', 'system']
assert context_window_size > 0
self.ontology = load_ontology(dataset_name)
self.speaker = speaker
self.opponent = 'system' if speaker == 'user' else 'user'
self.context_window_size = context_window_size
......@@ -24,7 +25,12 @@ class T5DST(DST):
logging.info("T5DST loaded")
def update(self, context):
def update(self, user_action=None):
if self.state['history'][0][1] == 'null':
# skip first dummy turn
context = self.state['history'][1:]
else:
context = self.state['history']
if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
context = [item[1] for item in context]
context = context[-self.context_window_size:]
......@@ -37,7 +43,17 @@ class T5DST(DST):
output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
# print(output_seq)
state = deserialize_dialogue_state(output_seq.strip())
return state
self.state['belief_state'] = state
return self.state
def init_session(self):
    """Reset the tracker for a new dialogue.

    Rebuilds ``self.state`` from scratch: the belief state is a deep copy
    of the ontology's state template (so per-session mutation never leaks
    back into the shared ontology), and all history/action bookkeeping
    fields start empty with the dialogue not yet terminated.
    """
    self.state = {
        'belief_state': deepcopy(self.ontology['state']),
        'booked': {},
        'history': [],
        'system_action': [],
        'user_action': [],
        'terminated': False,
    }
if __name__ == '__main__':
......@@ -59,7 +75,9 @@ if __name__ == '__main__':
"You are welcome. Is there anything else I can help you with today?",
"No, I am all set. Have a nice day. Bye."],
]
dst = T5DST(speaker='user', context_window_size=100, model_name_or_path='output/dst/multiwoz21/user/context_100')
dst = T5DST('multiwoz21', speaker='user', context_window_size=100, model_name_or_path='ConvLab/t5-small-dst-multiwoz21')
dst.init_session()
for context in contexts:
print(dst.update(context))
dst.state['history'] = context
print(dst.update())
print()
import logging
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from convlab.nlg.nlg import NLG
from convlab.base_models.t5.nlu.serialization import serialize_dialogue_acts
from convlab.util.custom_util import model_downloader
class T5NLG(NLG):
......@@ -33,7 +31,18 @@ class T5NLG(NLG):
else:
utts = ['']
input_seq = '\n'.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
if isinstance(dialogue_acts, dict):
# da in unified format
dialogue_acts_seq = serialize_dialogue_acts(dialogue_acts)
elif isinstance(dialogue_acts[0], dict):
# da without da type
dialogue_acts_seq = serialize_dialogue_acts({'categorical': dialogue_acts})
elif isinstance(dialogue_acts[0], list):
# da is a list of list (convlab-2 format)
dialogue_acts_seq = serialize_dialogue_acts(
{'categorical': [{'intent': da[0], 'domain': da[1], 'slot': da[2], 'value': da[3]} for da in dialogue_acts]})
else:
raise ValueError(f"invalid dialog acts format {dialogue_acts}")
input_seq = dialogue_acts_seq + '\n' + input_seq
# print(input_seq)
input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device)
......@@ -47,7 +56,7 @@ class T5NLG(NLG):
if __name__ == '__main__':
das = [
{
{ # da in unified format
"categorical": [],
"non-categorical": [],
"binary": [
......@@ -63,9 +72,7 @@ if __name__ == '__main__':
}
]
},
{
"categorical": [],
"non-categorical": [
[ # da without da type
{
"intent": "inform",
"domain": "taxi",
......@@ -83,25 +90,9 @@ if __name__ == '__main__':
"end": 78
}
],
"binary": [
{
"intent": "book",
"domain": "taxi",
"slot": ""
}
]
},
{
"categorical": [],
"non-categorical": [],
"binary": [
{
"intent": "reqmore",
"domain": "general",
"slot": ""
}
]
},
[ # da is a list of list (convlab-2 format)
["reqmore", "general", "", ""]
],
{
"categorical": [],
"non-categorical": [],
......@@ -132,7 +123,7 @@ if __name__ == '__main__':
"You are welcome. Is there anything else I can help you with today?"
"No, I am all set. Have a nice day. Bye."],
]
nlg = T5NLG(speaker='system', context_window_size=0, model_name_or_path='output/nlg/multiwoz21/system/context_3')
nlg = T5NLG(speaker='system', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlg-multiwoz21')
for da, context in zip(das, contexts):
print(da)
print(nlg.generate(da, context))
......
import logging
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from convlab.nlu.nlu import NLU
from convlab.base_models.t5.nlu.serialization import deserialize_dialogue_acts
from convlab.util.custom_util import model_downloader
class T5NLU(NLU):
......
......@@ -19,6 +19,17 @@ class SCGPT(NLG):
self.model.load_state_dict(torch.load(model_path))
def generate(self, action):
    """Generate a system utterance from dialogue acts.

    Accepts acts in three shapes and normalizes them to the unified
    dict format before serialization:
      * dict                  -- already unified, used as-is
      * list of act dicts     -- wrapped under the 'categorical' da type
      * list of 4-item lists  -- ConvLab-2 [intent, domain, slot, value]

    Raises ValueError for anything else.
    """
    if not isinstance(action, dict):
        head = action[0]
        if isinstance(head, dict):
            # da without da type: wrap the bare act list
            action = {'categorical': action}
        elif isinstance(head, list):
            # convlab-2 quadruple format -> unified act dicts
            action = {'categorical': [
                {'intent': da[0], 'domain': da[1], 'slot': da[2], 'value': da[3]}
                for da in action
            ]}
        else:
            raise ValueError(f"invalid dialog acts format {action}")
    action_str = act2str(action)
    return self._inference_batch([action_str])[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment