From e7f924e97dbeded5753b818c0629a4121261de08 Mon Sep 17 00:00:00 2001 From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com> Date: Wed, 25 Jan 2023 15:01:07 +0100 Subject: [PATCH] Scgpt generation fix (#128) * Separate test and train domains * Add progress bars in ontology embedder * Update custom_util.py * Fix custom_util things I broke * Github master * Save dialogue ids in prediction file * Fix bug in ontology extraction * Return dialogue ids in predictions file and fix bugs * Add setsumbt starting config loader * Add script to extract golden labels from dataset to match model predictions * Add more setsumbt configs * Add option to use local files only in transformers package * Update starting configurations for setsumbt * Github master * Update README.md * Update README.md * Update convlab/dialog_agent/agent.py * Revert custom_util.py * Update custom_util.py * Commit unverified changes :(:(:(:( * Fix SetSUMBT bug resulting from new torch feature * Setsumbt bug fixes * Policy config refactor * Policy config refactor * small bug fix in memory with new config path * Setsumbt info dict * Fix generate function for SCGPT * SCGPT default device GPU Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de> Co-authored-by: Michael Heck <michael.heck@hhu.de> Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de> --- convlab/nlg/scgpt/scgpt.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py index ee591e79..def3b2f3 100644 --- a/convlab/nlg/scgpt/scgpt.py +++ b/convlab/nlg/scgpt/scgpt.py @@ -1,3 +1,4 @@ +import pdb import sys sys.path.append('../../..') @@ -6,17 +7,17 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config from torch.nn.parallel import DistributedDataParallel as DDP from convlab.nlg.nlg import NLG -from util import act2str -from scgpt_special_tokens import * +from convlab.nlg.scgpt.util import act2str class SCGPT(NLG): - def 
__init__(self, dataset_name, model_path, device='cpu'): + def __init__(self, dataset_name, model_path, device='gpu'): super(SCGPT, self).__init__() + self.dataset_name = dataset_name self.device = device self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device) self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') - self.model.load_state_dict(torch.load(model_path)) + self.model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device))) def generate(self, action): if isinstance(action, dict): @@ -50,5 +51,5 @@ class SCGPT(NLG): if self.tokenizer.eos_token in sent: sent = sent[:sent.index(self.tokenizer.eos_token)] return sent - output_strs = [clean_sentence(item) for item in outputs] + output_strs = [clean_sentence(self.tokenizer.decode(item, skip_special_tokens=True)) for item in outputs] return output_strs \ No newline at end of file -- GitLab