Unverified commit 8cabb843, authored by zhuqi, committed by GitHub

Merge pull request #104 from ConvLab/readme

update nlu interface: auto preprocess for jointbert; transform da output for t5nlu
Parents: 3154fc7d 58f9a0d4
 import logging
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
 from convlab.nlu.nlu import NLU
@@ -38,8 +38,11 @@ class T5NLU(NLU):
         # print(output_seq)
         output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
         # print(output_seq)
-        dialogue_acts = deserialize_dialogue_acts(output_seq.strip())
-        return dialogue_acts
+        das = deserialize_dialogue_acts(output_seq.strip())
+        dialog_act = []
+        for da in das:
+            dialog_act.append([da['intent'], da['domain'], da['slot'], da.get('value','')])
+        return dialog_act

 if __name__ == '__main__':
@@ -67,7 +70,7 @@ if __name__ == '__main__':
         "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.",
         "Could you double check that you've spelled the name correctly? The closest I can find is Nandos."]
     ]
-    nlu = T5NLU(speaker='user', context_window_size=3, model_name_or_path='output/nlu/multiwoz21/user/context_3')
+    nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')
     for text, context in zip(texts, contexts):
         print(text)
         print(nlu.predict(text, context))
...
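
The T5NLU change above stops returning the raw output of deserialize_dialogue_acts and instead flattens each act into the [intent, domain, slot, value] quadruples that the rest of ConvLab consumes. A minimal sketch of the transform in isolation, assuming the deserialized acts are dicts with 'intent', 'domain', 'slot', and an optional 'value' key (the keys come from the diff; the sample acts are made up for illustration):

    # Stand-in for deserialize_dialogue_acts() output: the dict keys mirror
    # the diff above, the concrete values are illustrative only.
    das = [
        {'intent': 'inform', 'domain': 'restaurant', 'slot': 'food', 'value': 'british'},
        {'intent': 'request', 'domain': 'restaurant', 'slot': 'price range'},  # no 'value'
    ]

    dialog_act = []
    for da in das:
        # A missing 'value' defaults to '' so every act is a uniform 4-element list.
        dialog_act.append([da['intent'], da['domain'], da['slot'], da.get('value', '')])

    print(dialog_act)
    # [['inform', 'restaurant', 'food', 'british'], ['request', 'restaurant', 'price range', '']]
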
@@ -7,6 +7,7 @@ import transformers
 from convlab.nlu.nlu import NLU
 from convlab.nlu.jointBERT.dataloader import Dataloader
 from convlab.nlu.jointBERT.jointBERT import JointBERT
+from convlab.nlu.jointBERT.unified_datasets.preprocess import preprocess
 from convlab.nlu.jointBERT.unified_datasets.postprocess import recover_intent
 from convlab.util.custom_util import model_downloader
@@ -25,7 +26,9 @@ class BERTNLU(NLU):
         data_dir = os.path.join(root_dir, config['data_dir'])
         output_dir = os.path.join(root_dir, config['output_dir'])
-        assert os.path.exists(os.path.join(data_dir, 'intent_vocab.json')), print('Please run preprocess first')
+        if not os.path.exists(os.path.join(data_dir, 'intent_vocab.json')):
+            print('Run preprocess first')
+            preprocess(config['dataset_name'], data_dir.split('/')[-2], os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data'), int(data_dir.split('/')[-1].split('_')[-1]))
         intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
         tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
...
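
The BERTNLU change replaces a hard assert with an automatic fallback: when intent_vocab.json is missing, preprocessing runs on the fly, with the speaker and context window size recovered from the data_dir path itself. A runnable sketch of that guard under the path layout the indexing implies ('.../<speaker>/<prefix>_<context_window_size>'); ensure_preprocessed and fake_preprocess are hypothetical names, and only the guard and the indexing mirror the diff:

    import os

    def ensure_preprocessed(data_dir, dataset_name, preprocess_fn, data_root):
        # Same guard as the diff: preprocess only when the vocab file is absent,
        # deriving the remaining arguments from the path.
        if not os.path.exists(os.path.join(data_dir, 'intent_vocab.json')):
            speaker = data_dir.split('/')[-2]                                  # e.g. 'user'
            context_window_size = int(data_dir.split('/')[-1].split('_')[-1])  # trailing _<n>
            preprocess_fn(dataset_name, speaker, data_root, context_window_size)

    # Stub standing in for the imported preprocess(); the real one writes
    # intent_vocab.json and tag_vocab.json under the data root.
    def fake_preprocess(dataset_name, speaker, data_root, context_window_size):
        print(dataset_name, speaker, data_root, context_window_size)

    ensure_preprocessed('data/multiwoz21/user/context_window_size_3',
                        'multiwoz21', fake_preprocess, 'data')
    # prints: multiwoz21 user data 3
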