From 26efc1edd7062450fb9ba25d613427a28601c938 Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Thu, 19 Jan 2023 15:56:24 +0100 Subject: [PATCH] wip --- convlab/nlu/jointBERT/multiwoz/nlu.py | 10 ++++++---- .../policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/convlab/nlu/jointBERT/multiwoz/nlu.py b/convlab/nlu/jointBERT/multiwoz/nlu.py index 1373919e..7f707766 100755 --- a/convlab/nlu/jointBERT/multiwoz/nlu.py +++ b/convlab/nlu/jointBERT/multiwoz/nlu.py @@ -41,12 +41,13 @@ class BERTNLU(NLU): dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab, pretrained_weights=config['model']['pretrained_weights']) - logging.info('intent num:' + str(len(intent_vocab))) + logging.info('intent num:' + str(len(intent_vocab))) logging.info('tag num:' + str(len(tag_vocab))) if not os.path.exists(output_dir): model_downloader(root_dir, model_file) - model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim) + model = JointBERT(config['model'], DEVICE, + dataloader.tag_dim, dataloader.intent_dim) state_dict = torch.load(os.path.join( output_dir, 'pytorch_model.bin'), DEVICE) @@ -74,7 +75,7 @@ class BERTNLU(NLU): for token in token_list: token = token.strip() self.nlp.tokenizer.add_special_case( - #token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) + # token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) token, [{ORTH: token}]) logging.info("BERTNLU loaded") @@ -97,7 +98,8 @@ class BERTNLU(NLU): intents = [] da = {} - word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq) + word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize( + ori_word_seq, ori_tag_seq) word_seq = word_seq[:510] tag_seq = tag_seq[:510] batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq, diff --git a/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json b/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json index 9cdbd8bb..57dcd4de 100644 --- a/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json +++ b/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json @@ -23,9 +23,10 @@ }, "nlu_sys": { "BertNLU": { - "class_path": "convlab.nlu.jointBERT.multiwoz.BERTNLU", + "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU", "ini_params": { - "config_file": "multiwoz_all.json", + "mode": "all", + "config_file": "multiwoz21_sys_context3.json", "model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_all.zip" } } -- GitLab