diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py index ea099442ddd18d0cd36a79db13b1f47788eb4fd4..d250f29e0186732b27fb284273d9f9ff4f166d2f 100644 --- a/convlab/dst/setsumbt/do/nbt.py +++ b/convlab/dst/setsumbt/do/nbt.py @@ -20,6 +20,7 @@ import os from shutil import copy2 as copy import json from copy import deepcopy +import pdb import torch import transformers @@ -34,6 +35,7 @@ from convlab.dst.setsumbt.modeling import training from convlab.dst.setsumbt.dataset import ontology as embeddings from convlab.dst.setsumbt.utils import get_args, update_args from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble +from convlab.util.custom_util import model_downloader # Available model @@ -55,6 +57,23 @@ def main(args=None, config=None): # Set up output directory OUTPUT_DIR = args.output_dir + + # Download model if needed + if not os.path.exists(OUTPUT_DIR): + # Get path /.../convlab/dst/setsumbt/multiwoz/models + download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + download_path = os.path.join(download_path, 'models') + if not os.path.exists(download_path): + os.mkdir(download_path) + model_downloader(download_path, OUTPUT_DIR) + # Downloadable model path format http://.../model_name.zip + OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '') + OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR) + + args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1]) + args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1]) + os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) + if not os.path.exists(OUTPUT_DIR): os.makedirs(OUTPUT_DIR) os.mkdir(os.path.join(OUTPUT_DIR, 'database')) diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index 6b620247fd4a36223fbed8c46c54615f7c69da98..eca7f1749369f9569d6b923312a93cd317e0701c 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -61,8 +61,8 @@ class SetSUMBTTracker(DST): if not os.path.exists(download_path): os.mkdir(download_path) model_downloader(download_path, self.model_path) - # Downloadable model path format http://.../setsumbt_model_name.zip - self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '') + # Downloadable model path format http://.../model_name.zip + self.model_path = self.model_path.split('/')[-1].replace('.zip', '') self.model_path = os.path.join(download_path, self.model_path) # Select model type based on the encoder