From 58f9a0d4586bf97cf9195c7a3b5ad8eb4c713d90 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Tue, 13 Dec 2022 12:07:05 +0800 Subject: [PATCH] update nlu interface: auto preprocess for jointbert; transform da output for t5nlu --- convlab/base_models/t5/nlu/nlu.py | 11 +++++++---- convlab/nlu/jointBERT/unified_datasets/nlu.py | 5 ++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/convlab/base_models/t5/nlu/nlu.py b/convlab/base_models/t5/nlu/nlu.py index 4162aa3f..ef679ffa 100755 --- a/convlab/base_models/t5/nlu/nlu.py +++ b/convlab/base_models/t5/nlu/nlu.py @@ -1,4 +1,4 @@ -import logging +import logging import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig from convlab.nlu.nlu import NLU @@ -38,8 +38,11 @@ class T5NLU(NLU): # print(output_seq) output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True) # print(output_seq) - dialogue_acts = deserialize_dialogue_acts(output_seq.strip()) - return dialogue_acts + das = deserialize_dialogue_acts(output_seq.strip()) + dialog_act = [] + for da in das: + dialog_act.append([da['intent'], da['domain'], da['slot'], da.get('value','')]) + return dialog_act if __name__ == '__main__': @@ -67,7 +70,7 @@ if __name__ == '__main__': "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.", "Could you double check that you've spelled the name correctly? 
The closest I can find is Nandos."] ] - nlu = T5NLU(speaker='user', context_window_size=3, model_name_or_path='output/nlu/multiwoz21/user/context_3') + nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21') for text, context in zip(texts, contexts): print(text) print(nlu.predict(text, context)) diff --git a/convlab/nlu/jointBERT/unified_datasets/nlu.py b/convlab/nlu/jointBERT/unified_datasets/nlu.py index 85cf56d9..311f0b6f 100755 --- a/convlab/nlu/jointBERT/unified_datasets/nlu.py +++ b/convlab/nlu/jointBERT/unified_datasets/nlu.py @@ -7,6 +7,7 @@ import transformers from convlab.nlu.nlu import NLU from convlab.nlu.jointBERT.dataloader import Dataloader from convlab.nlu.jointBERT.jointBERT import JointBERT +from convlab.nlu.jointBERT.unified_datasets.preprocess import preprocess from convlab.nlu.jointBERT.unified_datasets.postprocess import recover_intent from convlab.util.custom_util import model_downloader @@ -25,7 +26,9 @@ class BERTNLU(NLU): data_dir = os.path.join(root_dir, config['data_dir']) output_dir = os.path.join(root_dir, config['output_dir']) - assert os.path.exists(os.path.join(data_dir, 'intent_vocab.json')), print('Please run preprocess first') + if not os.path.exists(os.path.join(data_dir, 'intent_vocab.json')): + print('Run preprocess first') + preprocess(config['dataset_name'], data_dir.split('/')[-2], os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data'), int(data_dir.split('/')[-1].split('_')[-1])) intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json'))) tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json'))) -- GitLab