diff --git a/convlab2/nlu/jointBERT/multiwoz/nlu.py b/convlab2/nlu/jointBERT/multiwoz/nlu.py index 2bc7ec61dfd89d8255b3c01dfa6e20939a577e44..17c24d98196eee90523bb8d40d3efc44851f8c5a 100755 --- a/convlab2/nlu/jointBERT/multiwoz/nlu.py +++ b/convlab2/nlu/jointBERT/multiwoz/nlu.py @@ -66,15 +66,15 @@ class BERTNLU(NLU): if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1: context = [item[1] for item in context] context_seq = self.dataloader.tokenizer.encode('[CLS] ' + ' [SEP] '.join(context[-3:])) - context_seq = context_seq[:self.dataloader.tokenizer.max_model_input_sizes] + context_seq = context_seq[:512] else: context_seq = self.dataloader.tokenizer.encode('[CLS]') intents = [] da = {} word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq) - word_seq = word_seq[:self.dataloader.tokenizer.max_model_input_sizes] - tag_seq = tag_seq[:self.dataloader.tokenizer.max_model_input_sizes] + word_seq = word_seq[:512] + tag_seq = tag_seq[:512] batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq, new2ori, word_seq, self.dataloader.seq_tag2id(tag_seq), self.dataloader.seq_intent2id(intents)]]