From 726a632684abc04b35d6bfe75260788ae2268529 Mon Sep 17 00:00:00 2001 From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com> Date: Wed, 8 Dec 2021 19:42:28 +0100 Subject: [PATCH] Update scgpt.py --- convlab2/nlg/scgpt/multiwoz/scgpt.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/convlab2/nlg/scgpt/multiwoz/scgpt.py b/convlab2/nlg/scgpt/multiwoz/scgpt.py index b5347aa6..5c933cad 100644 --- a/convlab2/nlg/scgpt/multiwoz/scgpt.py +++ b/convlab2/nlg/scgpt/multiwoz/scgpt.py @@ -26,12 +26,15 @@ class SCGPT(NLG): model_dir = os.path.dirname(os.path.abspath(__file__)) if not os.path.isfile(model_file): model_file = cached_path(model_file) + if not os.path.isdir(model_file): archive = zipfile.ZipFile(model_file, 'r') archive.extractall(model_dir) - # Get model directory - model_file = archive.filelist[0].filename.replace('/', '') - - self.model_name_or_path = os.path.join(model_dir, model_file) + # Get model directory + model_file = archive.filelist[0].filename.replace('/', '') + self.model_name_or_path = os.path.join(model_dir, model_file) + else: + self.model_name_or_path = model_file + self.length = 50 self.num_samples = 5 self.temperature = 1.0 @@ -102,4 +105,4 @@ class SCGPT(NLG): text = text.split('& ')[-1] text = text[: text.find(self.stop_token) if self.stop_token else None] - return text \ No newline at end of file + return text -- GitLab