Skip to content
Snippets Groups Projects
Commit a2d86eff authored by Carel van Niekerk's avatar Carel van Niekerk
Browse files

Update model loader

parent 7a7381cb
No related branches found
No related tags found
No related merge requests found
...@@ -10,23 +10,28 @@ from convlab2.nlg.nlg import NLG ...@@ -10,23 +10,28 @@ from convlab2.nlg.nlg import NLG
from convlab2.util.file_util import cached_path from convlab2.util.file_util import cached_path
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "nlg-gpt-multiwoz.zip")
class SCGPT(NLG): class SCGPT(NLG):
def __init__(self, def __init__(self, model_file=None,
archive_file=DEFAULT_ARCHIVE_FILE, use_cuda=True, is_user=False):
use_cuda=True, # If no filename is mentioned then set to default
is_user=False, if not model_file:
model_file='https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'): if is_user:
model_file = 'https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'
else:
model_file = 'https://zenodo.org/record/5767426/files/neo_scgpt_system.zip'
# Load from file/url
model_dir = os.path.dirname(os.path.abspath(__file__)) model_dir = os.path.dirname(os.path.abspath(__file__))
if not os.path.isfile(archive_file): if not os.path.isfile(model_file):
archive_file = cached_path(model_file) model_file = cached_path(model_file)
archive = zipfile.ZipFile(archive_file, 'r') archive = zipfile.ZipFile(model_file, 'r')
archive.extractall(model_dir) archive.extractall(model_dir)
# Get model directory
model_file = archive.filelist[0].filename.replace('/', '')
self.model_name_or_path = os.path.join(model_dir, 'multiwoz') self.model_name_or_path = os.path.join(model_dir, model_file)
self.length = 50 self.length = 50
self.num_samples = 5 self.num_samples = 5
self.temperature = 1.0 self.temperature = 1.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment