Skip to content
Snippets Groups Projects
Commit 0d233eee authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Bug fix and update downloader for setsumbt

parent bd7a47d0
Branches
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@ import os
from shutil import copy2 as copy
import json
from copy import deepcopy
import pdb
import torch
import transformers
......@@ -34,6 +35,7 @@ from convlab.dst.setsumbt.modeling import training
from convlab.dst.setsumbt.dataset import ontology as embeddings
from convlab.dst.setsumbt.utils import get_args, update_args
from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble
from convlab.util.custom_util import model_downloader
# Available model
......@@ -55,6 +57,23 @@ def main(args=None, config=None):
# Set up output directory
OUTPUT_DIR = args.output_dir
# Download model if needed
if not os.path.exists(OUTPUT_DIR):
# Get path /.../convlab/dst/setsumbt/multiwoz/models
download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
download_path = os.path.join(download_path, 'models')
if not os.path.exists(download_path):
os.mkdir(download_path)
model_downloader(download_path, OUTPUT_DIR)
# Downloadable model path format http://.../model_name.zip
OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '')
OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR)
args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1])
args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1])
os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR)
os.mkdir(os.path.join(OUTPUT_DIR, 'database'))
......
......@@ -61,8 +61,8 @@ class SetSUMBTTracker(DST):
if not os.path.exists(download_path):
os.mkdir(download_path)
model_downloader(download_path, self.model_path)
# Downloadable model path format http://.../setsumbt_model_name.zip
self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '')
# Downloadable model path format http://.../model_name.zip
self.model_path = self.model_path.split('/')[-1].replace('.zip', '')
self.model_path = os.path.join(download_path, self.model_path)
# Select model type based on the encoder
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment