From e7f924e97dbeded5753b818c0629a4121261de08 Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com>
Date: Wed, 25 Jan 2023 15:01:07 +0100
Subject: [PATCH] Scgpt generation fix (#128)

* Seperate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* Github master

* Save dialogue ids in prediction file

* Fix bug in ontology enxtraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* Github master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified chnages :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* Setsumbt bug fixes

* Policy config refactor

* Policy config refactor

* small bug fix in memory with new config path

* Setsumbt info dict

* Fix generate function for SCGPT

* SCGPT default device GPU

Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: Michael Heck <michael.heck@hhu.de>
Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de>
---
 convlab/nlg/scgpt/scgpt.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py
index ee591e79..def3b2f3 100644
--- a/convlab/nlg/scgpt/scgpt.py
+++ b/convlab/nlg/scgpt/scgpt.py
@@ -1,3 +1,4 @@
+import pdb
 import sys
 sys.path.append('../../..')
 
@@ -6,17 +7,17 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from convlab.nlg.nlg import NLG
-from util import act2str
-from scgpt_special_tokens import *
+from convlab.nlg.scgpt.util import act2str
 
 
 class SCGPT(NLG):
-    def __init__(self, dataset_name, model_path, device='cpu'):
+    def __init__(self, dataset_name, model_path, device='gpu'):
         super(SCGPT, self).__init__()
+        self.dataset_name = dataset_name
         self.device = device
         self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device)
         self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
-        self.model.load_state_dict(torch.load(model_path))
+        self.model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))
 
     def generate(self, action):
         if isinstance(action, dict):
@@ -50,5 +51,5 @@ class SCGPT(NLG):
                 if self.tokenizer.eos_token in sent:
                     sent = sent[:sent.index(self.tokenizer.eos_token)]
                 return sent
-            output_strs = [clean_sentence(item) for item in outputs]
+            output_strs = [clean_sentence(self.tokenizer.decode(item, skip_special_tokens=True)) for item in outputs]
             return output_strs
\ No newline at end of file
-- 
GitLab