diff --git a/convlab/base_models/t5/dst/dst.py b/convlab/base_models/t5/dst/dst.py
index 3c5ec2525091cfdcb423460ed1b01871087deb21..0fbc1cb19b23421229a8cacadf333b44fc8ebe58 100755
--- a/convlab/base_models/t5/dst/dst.py
+++ b/convlab/base_models/t5/dst/dst.py
@@ -1,16 +1,17 @@
 import logging
-import os
 import torch
+from copy import deepcopy
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
 from convlab.dst.dst import DST
 from convlab.base_models.t5.dst.serialization import deserialize_dialogue_state
-from convlab.util.custom_util import model_downloader
+from convlab.util import load_ontology
 
 
 class T5DST(DST):
-    def __init__(self, speaker, context_window_size, model_name_or_path, device='cuda'):
+    def __init__(self, dataset_name, speaker, context_window_size, model_name_or_path, device='cuda'):
         assert speaker in ['user', 'system']
         assert context_window_size > 0
+        self.ontology = load_ontology(dataset_name)
         self.speaker = speaker
         self.opponent = 'system' if speaker == 'user' else 'user'
         self.context_window_size = context_window_size
@@ -24,7 +25,12 @@ class T5DST(DST):
 
         logging.info("T5DST loaded")
 
-    def update(self, context):
+    def update(self, user_action=None):
+        if self.state['history'][0][1] == 'null':
+            # skip first dummy turn
+            context = self.state['history'][1:]
+        else:
+            context = self.state['history']
         if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
             context = [item[1] for item in context]
         context = context[-self.context_window_size:]
@@ -37,7 +43,17 @@ class T5DST(DST):
         output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
         # print(output_seq)
         state = deserialize_dialogue_state(output_seq.strip())
-        return state
+        self.state['belief_state'] = state
+        return self.state
+
+    def init_session(self):
+        self.state = dict()
+        self.state['belief_state'] = deepcopy(self.ontology['state'])
+        self.state['booked'] = dict()
+        self.state['history'] = []
+        self.state['system_action'] = []
+        self.state['user_action'] = []
+        self.state['terminated'] = False
 
 
 if __name__ == '__main__':
@@ -59,7 +75,9 @@ if __name__ == '__main__':
         "You are welcome. Is there anything else I can help you with today?",
         "No, I am all set. Have a nice day. Bye."],
     ]
-    dst = T5DST(speaker='user', context_window_size=100, model_name_or_path='output/dst/multiwoz21/user/context_100')
+    dst = T5DST('multiwoz21', speaker='user', context_window_size=100, model_name_or_path='ConvLab/t5-small-dst-multiwoz21')
+    dst.init_session()
     for context in contexts:
-        print(dst.update(context))
+        dst.state['history'] = context
+        print(dst.update())
         print()
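Note on the T5DST interface change above: update() no longer takes the dialogue context as an argument. The tracker now follows ConvLab's stateful DST convention, reading the context from self.state['history'] and writing the decoded state into self.state['belief_state']. A minimal sketch of the new calling convention, assuming the ConvLab/t5-small-dst-multiwoz21 checkpoint referenced in the diff is available and a CUDA device is present (the constructor's default); the example turns are invented:

from convlab.base_models.t5.dst.dst import T5DST

dst = T5DST('multiwoz21', speaker='user', context_window_size=100,
            model_name_or_path='ConvLab/t5-small-dst-multiwoz21')
dst.init_session()  # resets belief_state from the ontology and clears history

# history holds [speaker, utterance] pairs; update() reads them from self.state
dst.state['history'].append(['user', 'I need a cheap hotel in the north.'])
dst.state['history'].append(['system', 'Okay, for how many nights?'])
dst.state['history'].append(['user', 'Two nights, starting Friday.'])

state = dst.update()
print(state['belief_state'])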
Bye."], ] - dst = T5DST(speaker='user', context_window_size=100, model_name_or_path='output/dst/multiwoz21/user/context_100') + dst = T5DST('multiwoz21', speaker='user', context_window_size=100, model_name_or_path='ConvLab/t5-small-dst-multiwoz21') + dst.init_session() for context in contexts: - print(dst.update(context)) + dst.state['history'] = context + print(dst.update()) print() diff --git a/convlab/base_models/t5/nlg/nlg.py b/convlab/base_models/t5/nlg/nlg.py index 214dc01eed75cfbb85a740e8f5fee8a759d813b0..d3413d7c48c6099aa3ab46baf05f89621a71732e 100755 --- a/convlab/base_models/t5/nlg/nlg.py +++ b/convlab/base_models/t5/nlg/nlg.py @@ -1,10 +1,8 @@ import logging -import os import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig from convlab.nlg.nlg import NLG from convlab.base_models.t5.nlu.serialization import serialize_dialogue_acts -from convlab.util.custom_util import model_downloader class T5NLG(NLG): @@ -33,7 +31,18 @@ class T5NLG(NLG): else: utts = [''] input_seq = '\n'.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)]) - dialogue_acts_seq = serialize_dialogue_acts(dialogue_acts) + if isinstance(dialogue_acts, dict): + # da in unified format + dialogue_acts_seq = serialize_dialogue_acts(dialogue_acts) + elif isinstance(dialogue_acts[0], dict): + # da without da type + dialogue_acts_seq = serialize_dialogue_acts({'categorical': dialogue_acts}) + elif isinstance(dialogue_acts[0], list): + # da is a list of list (convlab-2 format) + dialogue_acts_seq = serialize_dialogue_acts( + {'categorical': [{'intent': da[0], 'domain': da[1], 'slot': da[2], 'value': da[3]} for da in dialogue_acts]}) + else: + raise ValueError(f"invalid dialog acts format {dialogue_acts}") input_seq = dialogue_acts_seq + '\n' + input_seq # print(input_seq) input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device) @@ -47,7 +56,7 @@ class T5NLG(NLG): if __name__ == '__main__': das = [ - { + { # da in unified format "categorical": [], "non-categorical": [], "binary": [ @@ -63,9 +72,7 @@ if __name__ == '__main__': } ] }, - { - "categorical": [], - "non-categorical": [ + [ # da without da type { "intent": "inform", "domain": "taxi", @@ -83,25 +90,9 @@ if __name__ == '__main__': "end": 78 } ], - "binary": [ - { - "intent": "book", - "domain": "taxi", - "slot": "" - } - ] - }, - { - "categorical": [], - "non-categorical": [], - "binary": [ - { - "intent": "reqmore", - "domain": "general", - "slot": "" - } - ] - }, + [ # da is a list of list (convlab-2 format) + ["reqmore", "general", "", ""] + ], { "categorical": [], "non-categorical": [], @@ -132,7 +123,7 @@ if __name__ == '__main__': "You are welcome. Is there anything else I can help you with today?" "No, I am all set. Have a nice day. 
Bye."], ] - nlg = T5NLG(speaker='system', context_window_size=0, model_name_or_path='output/nlg/multiwoz21/system/context_3') + nlg = T5NLG(speaker='system', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlg-multiwoz21') for da, context in zip(das, contexts): print(da) print(nlg.generate(da, context)) diff --git a/convlab/base_models/t5/nlu/nlu.py b/convlab/base_models/t5/nlu/nlu.py index a5a6e6a23ec184b15fc073d88fa1a6b3fece34d8..4162aa3fd0ea6f4fbbde2620d45b9fb73e104875 100755 --- a/convlab/base_models/t5/nlu/nlu.py +++ b/convlab/base_models/t5/nlu/nlu.py @@ -1,10 +1,8 @@ import logging -import os import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig from convlab.nlu.nlu import NLU from convlab.base_models.t5.nlu.serialization import deserialize_dialogue_acts -from convlab.util.custom_util import model_downloader class T5NLU(NLU): diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py index 4977b4f3778229293bb60c2e98b7e715ebe227dd..ee591e79ac0590582ec6a84be62bb19e31b31004 100644 --- a/convlab/nlg/scgpt/scgpt.py +++ b/convlab/nlg/scgpt/scgpt.py @@ -19,6 +19,17 @@ class SCGPT(NLG): self.model.load_state_dict(torch.load(model_path)) def generate(self, action): + if isinstance(action, dict): + # da in unified format + pass + elif isinstance(action[0], dict): + # da without da type + action = {'categorical': action} + elif isinstance(action[0], list): + # da is a list of list (convlab-2 format) + action = {'categorical': [{'intent': da[0], 'domain': da[1], 'slot': da[2], 'value': da[3]} for da in action]} + else: + raise ValueError(f"invalid dialog acts format {action}") action_str = act2str(action) output = self._inference_batch([action_str])[0] return output