diff --git a/convlab/base_models/t5/nlu/nlu.py b/convlab/base_models/t5/nlu/nlu.py
index 4162aa3fd0ea6f4fbbde2620d45b9fb73e104875..ef679ffa9df6fa03f70cc82c98f64f910543b400 100755
--- a/convlab/base_models/t5/nlu/nlu.py
+++ b/convlab/base_models/t5/nlu/nlu.py
@@ -1,4 +1,4 @@
-import logging
+import logging
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
 from convlab.nlu.nlu import NLU
@@ -38,8 +38,11 @@ class T5NLU(NLU):
         # print(output_seq)
         output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
         # print(output_seq)
-        dialogue_acts = deserialize_dialogue_acts(output_seq.strip())
-        return dialogue_acts
+        das = deserialize_dialogue_acts(output_seq.strip())
+        dialog_act = []
+        for da in das:
+            dialog_act.append([da['intent'], da['domain'], da['slot'], da.get('value','')])
+        return dialog_act
 
 
 if __name__ == '__main__':
@@ -67,7 +70,7 @@ if __name__ == '__main__':
         "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.",
         "Could you double check that you've spelled the name correctly? The closest I can find is Nandos."]
     ]
-    nlu = T5NLU(speaker='user', context_window_size=3, model_name_or_path='output/nlu/multiwoz21/user/context_3')
+    nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')
     for text, context in zip(texts, contexts):
         print(text)
         print(nlu.predict(text, context))
diff --git a/convlab/nlu/jointBERT/unified_datasets/nlu.py b/convlab/nlu/jointBERT/unified_datasets/nlu.py
index 85cf56d94575317f668ea0b76e2d91ed23a572bb..311f0b6f8fe09c41ac0e38332e7ea467a621bb17 100755
--- a/convlab/nlu/jointBERT/unified_datasets/nlu.py
+++ b/convlab/nlu/jointBERT/unified_datasets/nlu.py
@@ -7,6 +7,7 @@ import transformers
 from convlab.nlu.nlu import NLU
 from convlab.nlu.jointBERT.dataloader import Dataloader
 from convlab.nlu.jointBERT.jointBERT import JointBERT
+from convlab.nlu.jointBERT.unified_datasets.preprocess import preprocess
 from convlab.nlu.jointBERT.unified_datasets.postprocess import recover_intent
 from convlab.util.custom_util import model_downloader
 
@@ -25,7 +26,9 @@ class BERTNLU(NLU):
         data_dir = os.path.join(root_dir, config['data_dir'])
         output_dir = os.path.join(root_dir, config['output_dir'])
 
-        assert os.path.exists(os.path.join(data_dir, 'intent_vocab.json')), print('Please run preprocess first')
+        if not os.path.exists(os.path.join(data_dir, 'intent_vocab.json')):
+            print('Run preprocess first')
+            preprocess(config['dataset_name'], data_dir.split('/')[-2], os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data'), int(data_dir.split('/')[-1].split('_')[-1]))
 
         intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
         tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))