Skip to content
Snippets Groups Projects
Unverified Commit e7f924e9 authored by Carel van Niekerk's avatar Carel van Niekerk Committed by GitHub
Browse files

Scgpt generation fix (#128)


* Separate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* Github master

* Save dialogue ids in prediction file

* Fix bug in ontology extraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* Github master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified changes :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* Setsumbt bug fixes

* Policy config refactor

* Policy config refactor

* small bug fix in memory with new config path

* Setsumbt info dict

* Fix generate function for SCGPT

* SCGPT default device GPU

Co-authored-by: default avatarCarel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: default avatarMichael Heck <michael.heck@hhu.de>
Co-authored-by: default avatarChristian Geishauser <christian.geishauser@hhu.de>
parent 06968a02
No related branches found
No related tags found
No related merge requests found
import pdb
import sys import sys
sys.path.append('../../..') sys.path.append('../../..')
...@@ -6,17 +7,17 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config ...@@ -6,17 +7,17 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from convlab.nlg.nlg import NLG from convlab.nlg.nlg import NLG
from util import act2str from convlab.nlg.scgpt.util import act2str
from scgpt_special_tokens import *
class SCGPT(NLG): class SCGPT(NLG):
def __init__(self, dataset_name, model_path, device='cpu'): def __init__(self, dataset_name, model_path, device='gpu'):
super(SCGPT, self).__init__() super(SCGPT, self).__init__()
self.dataset_name = dataset_name
self.device = device self.device = device
self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device) self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device)
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
self.model.load_state_dict(torch.load(model_path)) self.model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))
def generate(self, action): def generate(self, action):
if isinstance(action, dict): if isinstance(action, dict):
...@@ -50,5 +51,5 @@ class SCGPT(NLG): ...@@ -50,5 +51,5 @@ class SCGPT(NLG):
if self.tokenizer.eos_token in sent: if self.tokenizer.eos_token in sent:
sent = sent[:sent.index(self.tokenizer.eos_token)] sent = sent[:sent.index(self.tokenizer.eos_token)]
return sent return sent
output_strs = [clean_sentence(item) for item in outputs] output_strs = [clean_sentence(self.tokenizer.decode(item, skip_special_tokens=True)) for item in outputs]
return output_strs return output_strs
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment