diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py index ee591e79ac0590582ec6a84be62bb19e31b31004..def3b2f3ae1f5a750049dc6af23e524a77789913 100644 --- a/convlab/nlg/scgpt/scgpt.py +++ b/convlab/nlg/scgpt/scgpt.py @@ -1,3 +1,4 @@ +import pdb import sys sys.path.append('../../..') @@ -6,17 +7,17 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config from torch.nn.parallel import DistributedDataParallel as DDP from convlab.nlg.nlg import NLG -from util import act2str -from scgpt_special_tokens import * +from convlab.nlg.scgpt.util import act2str class SCGPT(NLG): - def __init__(self, dataset_name, model_path, device='cpu'): + def __init__(self, dataset_name, model_path, device='gpu'): super(SCGPT, self).__init__() + self.dataset_name = dataset_name self.device = device self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device) self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') - self.model.load_state_dict(torch.load(model_path)) + self.model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device))) def generate(self, action): if isinstance(action, dict): @@ -50,5 +51,5 @@ class SCGPT(NLG): if self.tokenizer.eos_token in sent: sent = sent[:sent.index(self.tokenizer.eos_token)] return sent - output_strs = [clean_sentence(item) for item in outputs] + output_strs = [clean_sentence(self.tokenizer.decode(item, skip_special_tokens=True)) for item in outputs] return output_strs \ No newline at end of file