Skip to content
Snippets Groups Projects
Select Git revision
  • a9ef624ef3b404877d96b29a2932c8ca5b863f9c
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

scgpt.py

Blame
  • user avatar
    zqwerty authored
    17229c70
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    scgpt.py 2.38 KiB
    import sys
    sys.path.append('../../..')
    
    import torch
    from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
    from torch.nn.parallel import DistributedDataParallel as DDP
    
    from convlab.nlg.nlg import NLG
    from util import act2str
    from scgpt_special_tokens import *
    
    
    class SCGPT(NLG):
        """SC-GPT natural language generator backed by a GPT-2 medium LM head model.

        Dialog acts are serialised to a string with ``act2str``, an ``'&'``
        separator token is appended, and the model's continuation after that
        prompt is returned as the utterance.
        """

        def __init__(self, dataset_name, model_path, device='cpu'):
            """Build the GPT-2 medium model and tokenizer, then load fine-tuned weights.

            Args:
                dataset_name: Accepted for interface compatibility; not used by
                    this constructor.
                model_path: Path to a model state dict readable by ``torch.load``.
                device: Torch device on which the model and its inputs are placed.
            """
            super(SCGPT, self).__init__()
            self.device = device
            self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device)
            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
            # map_location lets a checkpoint saved on GPU be loaded on a
            # CPU-only host (and vice versa) instead of failing to deserialise.
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            state_dict = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            # A directly constructed module starts in training mode; switch to
            # eval mode so dropout does not perturb generation.
            self.model.eval()

        def generate(self, action):
            """Generate an utterance for the given dialog acts.

            Args:
                action: One of
                    * a dict (unified format), passed to ``act2str`` unchanged;
                    * a non-empty sequence of dicts, wrapped as
                      ``{'categorical': action}``;
                    * a non-empty sequence of lists
                      ``[intent, domain, slot, value]`` (ConvLab-2 format),
                      converted to dicts and wrapped the same way.

            Returns:
                The generated utterance string.

            Raises:
                ValueError: If ``action`` matches none of the formats above.
                IndexError: If ``action`` is an empty list (its first element
                    is inspected to detect the format).
            """
            if isinstance(action, dict):
                # da in unified format
                pass
            elif isinstance(action[0], dict):
                # da without da type
                action = {'categorical': action}
            elif isinstance(action[0], list):
                # da is a list of list (convlab-2 format)
                action = {'categorical': [{'intent': da[0], 'domain': da[1], 'slot': da[2], 'value': da[3]} for da in action]}
            else:
                raise ValueError(f"invalid dialog acts format {action}")
            action_str = act2str(action)
            return self._inference_batch([action_str])[0]

        def _inference_batch(self, sents):
            """Generate one continuation per serialised dialog-act string.

            Args:
                sents: Iterable of prompt strings.

            Returns:
                A list of strings, one per prompt, each cut at the first
                end-of-sequence token and stripped of surrounding whitespace.
                An empty input yields an empty list.
            """
            sents = list(sents)
            if not sents:
                # max() below would raise on an empty batch.
                return []

            # Every prompt is terminated by the '&' separator token.
            sep_id = self.tokenizer.convert_tokens_to_ids('&')
            sent_ids = [self.tokenizer.encode(sent) + [sep_id] for sent in sents]
            max_len = max(len(ids) for ids in sent_ids)
            # Right-pad with id 0; the attention mask below treats id 0 as padding.
            sent_ids = [ids + [0] * (max_len - len(ids)) for ids in sent_ids]

            eos_token = self.tokenizer.eos_token
            # The stock GPT-2 tokenizer defines no pad token, so pass the EOS id
            # explicitly for both stopping and post-EOS padding rather than
            # relying on generate() falling back to the model config.
            eos_id = self.tokenizer.eos_token_id

            def clean_sentence(sent):
                # Drop everything from the first EOS marker onwards.
                if eos_token in sent:
                    sent = sent[:sent.index(eos_token)]
                return sent.strip()

            with torch.no_grad():
                inputs = torch.LongTensor(sent_ids).to(self.device)
                # Unwrap DistributedDataParallel to reach generate().
                model_to_run = self.model.module if isinstance(self.model, DDP) else self.model
                outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs != 0).float(),
                                                eos_token_id=eos_id, pad_token_id=eos_id)
                # Keep only the newly generated tokens after the (padded) prompt.
                outputs = outputs[:, inputs.shape[1]:]

            # Token ids must be decoded to text before string clean-up; special
            # tokens are kept so clean_sentence can locate the EOS marker.
            return [clean_sentence(self.tokenizer.decode(item.tolist())) for item in outputs]