Skip to content
Snippets Groups Projects
Unverified Commit e7f924e9 authored by Carel van Niekerk's avatar Carel van Niekerk Committed by GitHub
Browse files

Scgpt generation fix (#128)


* Separate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* Github master

* Save dialogue ids in prediction file

* Fix bug in ontology extraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* Github master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified changes :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* Setsumbt bug fixes

* Policy config refactor

* Policy config refactor

* small bug fix in memory with new config path

* Setsumbt info dict

* Fix generate function for SCGPT

* SCGPT default device GPU

Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: Michael Heck <michael.heck@hhu.de>
Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de>
parent 06968a02
Branches
No related tags found
No related merge requests found
import pdb
import sys
sys.path.append('../../..')
......@@ -6,17 +7,17 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from torch.nn.parallel import DistributedDataParallel as DDP
from convlab.nlg.nlg import NLG
from util import act2str
from scgpt_special_tokens import *
from convlab.nlg.scgpt.util import act2str
class SCGPT(NLG):
    """Semantically-conditioned GPT-2 (SC-GPT) natural language generator.

    Wraps a GPT-2-medium language model fine-tuned to render dialogue acts
    as natural-language responses.
    """

    def __init__(self, dataset_name, model_path, device='gpu'):
        """Load the GPT-2-medium architecture and a fine-tuned checkpoint.

        Args:
            dataset_name: Name of the dataset the checkpoint was trained on.
            model_path: Path to the fine-tuned state-dict checkpoint.
            device: Torch device string. The historical default ``'gpu'`` is
                not a valid torch device; it is normalized to ``'cuda'`` when
                a GPU is available, otherwise ``'cpu'``.
        """
        super(SCGPT, self).__init__()
        self.dataset_name = dataset_name
        # BUG FIX: torch.device('gpu') raises RuntimeError — 'gpu' is not a
        # valid device string. Normalize it so the default actually works,
        # falling back to CPU on machines without CUDA.
        if device == 'gpu':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device)
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
        # map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))
def generate(self, action):
if isinstance(action, dict):
......@@ -50,5 +51,5 @@ class SCGPT(NLG):
if self.tokenizer.eos_token in sent:
sent = sent[:sent.index(self.tokenizer.eos_token)]
return sent
output_strs = [clean_sentence(item) for item in outputs]
output_strs = [clean_sentence(self.tokenizer.decode(item, skip_special_tokens=True)) for item in outputs]
return output_strs
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment