Skip to content
Snippets Groups Projects
Commit 26efc1ed authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent 7f9ee099
No related branches found
No related tags found
No related merge requests found
...@@ -46,7 +46,8 @@ class BERTNLU(NLU): ...@@ -46,7 +46,8 @@ class BERTNLU(NLU):
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
model_downloader(root_dir, model_file) model_downloader(root_dir, model_file)
model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim) model = JointBERT(config['model'], DEVICE,
dataloader.tag_dim, dataloader.intent_dim)
state_dict = torch.load(os.path.join( state_dict = torch.load(os.path.join(
output_dir, 'pytorch_model.bin'), DEVICE) output_dir, 'pytorch_model.bin'), DEVICE)
...@@ -97,7 +98,8 @@ class BERTNLU(NLU): ...@@ -97,7 +98,8 @@ class BERTNLU(NLU):
intents = [] intents = []
da = {} da = {}
word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq) word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(
ori_word_seq, ori_tag_seq)
word_seq = word_seq[:510] word_seq = word_seq[:510]
tag_seq = tag_seq[:510] tag_seq = tag_seq[:510]
batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq, batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq,
......
...@@ -23,9 +23,10 @@ ...@@ -23,9 +23,10 @@
}, },
"nlu_sys": { "nlu_sys": {
"BertNLU": { "BertNLU": {
"class_path": "convlab.nlu.jointBERT.multiwoz.BERTNLU", "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU",
"ini_params": { "ini_params": {
"config_file": "multiwoz_all.json", "mode": "all",
"config_file": "multiwoz21_sys_context3.json",
"model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_all.zip" "model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_all.zip"
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment