diff --git a/convlab2/nlg/scgpt/multiwoz/scgpt.py b/convlab2/nlg/scgpt/multiwoz/scgpt.py index b5347aa6af0ffc8443852f8ba9e6b39f820791bc..5c933cad0d3ef7aa71335cdf1cab65ea7b4cd795 100644 --- a/convlab2/nlg/scgpt/multiwoz/scgpt.py +++ b/convlab2/nlg/scgpt/multiwoz/scgpt.py @@ -26,12 +26,15 @@ class SCGPT(NLG): model_dir = os.path.dirname(os.path.abspath(__file__)) if not os.path.isfile(model_file): model_file = cached_path(model_file) + if not os.path.isdir(model_file): archive = zipfile.ZipFile(model_file, 'r') archive.extractall(model_dir) - # Get model directory - model_file = archive.filelist[0].filename.replace('/', '') - - self.model_name_or_path = os.path.join(model_dir, model_file) + # Get model directory + model_file = archive.filelist[0].filename.replace('/', '') + self.model_name_or_path = os.path.join(model_dir, model_file) + else: + self.model_name_or_path = model_file + self.length = 50 self.num_samples = 5 self.temperature = 1.0 @@ -102,4 +105,4 @@ class SCGPT(NLG): text = text.split('& ')[-1] text = text[: text.find(self.stop_token) if self.stop_token else None] - return text \ No newline at end of file + return text