Skip to content
Snippets Groups Projects
Select Git revision
  • 88f2d26317e4331013d8779c37bb2e00f418333e
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

nlu.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    nlu.py 5.04 KiB
    import logging
    import os
    import json
    import torch
    from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
    import transformers
    from convlab2.nlu.nlu import NLU
    from convlab2.nlu.jointBERT.dataloader import Dataloader
    from convlab2.nlu.jointBERT.jointBERT import JointBERT
    from convlab2.nlu.jointBERT.unified_datasets.postprocess import recover_intent
    from convlab2.util.custom_util import model_downloader
    
    
    class BERTNLU(NLU):
        """NLU component that predicts dialogue acts with a JointBERT model.

        ``__init__`` loads the JSON config, the intent/tag vocabularies and the
        model weights once. ``predict`` then maps a single utterance (plus
        optional dialogue context) to a list of ``[intent, domain, slot, value]``
        dialogue acts.
        """

        def __init__(self, mode, config_file, model_file=None):
            """Load config, vocabularies, tokenizers and model weights.

            Args:
                mode: One of ``'user'``, ``'sys'`` or ``'all'``; stored as ``self.mode``.
                config_file: File name of a JSON config in the ``configs/``
                    directory next to this module.
                model_file: Passed to ``model_downloader`` when the configured
                    ``output_dir`` does not exist yet.

            Raises:
                AssertionError: If ``mode`` is invalid, or ``intent_vocab.json`` is
                    missing from the configured ``data_dir`` (preprocessing not run).
            """
            assert mode in ('user', 'sys', 'all'), \
                "mode must be 'user', 'sys' or 'all', got {!r}".format(mode)
            self.mode = mode
            module_dir = os.path.dirname(os.path.abspath(__file__))
            config = self._load_json(os.path.join(module_dir, 'configs', config_file))
            # Use the configured device only when CUDA is actually available.
            device = config['DEVICE'] if torch.cuda.is_available() else 'cpu'
            root_dir = os.path.dirname(module_dir)
            data_dir = os.path.join(root_dir, config['data_dir'])
            output_dir = os.path.join(root_dir, config['output_dir'])

            intent_vocab_path = os.path.join(data_dir, 'intent_vocab.json')
            assert os.path.exists(intent_vocab_path), \
                'Please run preprocess first ({} not found)'.format(intent_vocab_path)

            intent_vocab = self._load_json(intent_vocab_path)
            tag_vocab = self._load_json(os.path.join(data_dir, 'tag_vocab.json'))
            dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
                                    pretrained_weights=config['model']['pretrained_weights'])

            logging.info('intent num:%d', len(intent_vocab))
            logging.info('tag num:%d', len(tag_vocab))

            if not os.path.exists(output_dir):
                model_downloader(root_dir, model_file)
            model = JointBERT(config['model'], device, dataloader.tag_dim, dataloader.intent_dim)

            state_dict = torch.load(os.path.join(output_dir, 'pytorch_model.bin'), device)
            # With transformers>=3, checkpoints lacking the position_ids entry are
            # patched (0..511) so the strict load_state_dict below does not fail.
            if int(transformers.__version__.split('.')[0]) >= 3 and 'bert.embeddings.position_ids' not in state_dict:
                state_dict['bert.embeddings.position_ids'] = torch.arange(512).reshape(1, -1).to(device)

            model.load_state_dict(state_dict)
            model.to(device)
            model.eval()

            self.model = model
            self.use_context = config['model']['context']
            self.context_window_size = config['context_window_size']
            self.dataloader = dataloader
            self.sent_tokenizer = PunktSentenceTokenizer()
            self.word_tokenizer = TreebankWordTokenizer()
            logging.info("BERTNLU loaded")

        @staticmethod
        def _load_json(path):
            """Parse the UTF-8 JSON file at ``path``, closing the handle afterwards."""
            with open(path, encoding='utf-8') as f:
                return json.load(f)

        def _encode_context(self, context):
            """Encode the last ``context_window_size`` turns as one token-id sequence.

            Turns are joined with ``' [SEP] '`` and the encoding is truncated to
            510 ids. When context is disabled in the config, the encoding of the
            empty string is returned instead.
            """
            if not self.use_context:
                return self.dataloader.tokenizer.encode('')
            if context and isinstance(context[0], (list, tuple)) and len(context[0]) > 1:
                # Turns given as sequences: keep only the element at index 1
                # (presumably the utterance text, e.g. [speaker, text]).
                context = [item[1] for item in context]
            window = self.context_window_size
            # Guard window <= 0: context[-0:] would select the entire history.
            recent = context[-window:] if window > 0 else []
            context_seq = self.dataloader.tokenizer.encode(' [SEP] '.join(recent))
            return context_seq[:510]

        def predict(self, utterance, context=None):
            """Predict dialogue acts for ``utterance``.

            Args:
                utterance: Raw text; split into sentences, then words, with NLTK.
                context: Optional previous turns, most recent last. Either a list of
                    strings or a list of sequences whose element at index 1 is the
                    turn text. Only used when the model config enables context.
                    ``None`` (the default) means no context.

            Returns:
                A list of ``[intent, domain, slot, value]`` lists; ``value`` is
                ``''`` when the recovered act carries none.
            """
            context = [] if context is None else context
            sentences = self.sent_tokenizer.tokenize(utterance)
            ori_word_seq = [token for sent in sentences for token in self.word_tokenizer.tokenize(sent)]
            ori_tag_seq = [str(('O',))] * len(ori_word_seq)
            context_seq = self._encode_context(context)
            # Empty label placeholders -- presumably mirroring the training batch layout.
            intents = []
            empty_da = {}

            word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq)
            word_seq = word_seq[:510]
            tag_seq = tag_seq[:510]
            batch_data = [[ori_word_seq, ori_tag_seq, intents, empty_da, context_seq,
                           new2ori, word_seq, self.dataloader.seq_tag2id(tag_seq), self.dataloader.seq_intent2id(intents)]]

            pad_batch = self.dataloader.pad_batch(batch_data)
            pad_batch = tuple(t.to(self.model.device) for t in pad_batch)
            word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
            # Inference only: disable autograd, and call the module (not .forward) so hooks run.
            with torch.no_grad():
                slot_logits, intent_logits = self.model(word_seq_tensor, word_mask_tensor,
                                                        context_seq_tensor=context_seq_tensor,
                                                        context_mask_tensor=context_mask_tensor)
            das = recover_intent(self.dataloader, intent_logits[0], slot_logits[0], tag_mask_tensor[0],
                                 ori_word_seq, new2ori)
            return [[act['intent'], act['domain'], act['slot'], act.get('value', '')]
                    for da_type in das for act in das[da_type]]
    
    
    if __name__ == '__main__':
        # Smoke test: run the MultiWOZ 2.1 user-side model over a few sample
        # utterances and print each utterance followed by its predicted acts.
        demo_utterances = (
            "I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
            "I want to leave after 17:15.",
            "Thank you for all the help! I appreciate it.",
            "Please find a restaurant called Nusha.",
            "What is the train id, please? ",
            "I don't care about the price and it doesn't need to have free parking.",
        )
        user_nlu = BERTNLU(mode='user', config_file='multiwoz21_user.json')
        for utterance in demo_utterances:
            acts = user_nlu.predict(utterance)
            print(utterance)
            print(acts)
            print()